In [1]:
import torch, os, evaluate, sys
import torch.nn as nn
import pandas as pd

import dataclasses, sys, pickle, json
import torch.nn.functional as F

from typing import List, Optional

import matplotlib.pyplot as plt, random, numpy as np

from datasets import load_dataset,Dataset,DatasetDict
from transformers import DataCollatorWithPadding,AutoModelForSequenceClassification, Trainer, TrainingArguments,AutoTokenizer,AutoModel,AutoConfig
from transformers.modeling_outputs import TokenClassifierOutput
from torch.utils.data import DataLoader
from transformers import AdamW,get_scheduler
from datasets import load_metric
from tqdm.auto import tqdm

from torch.nn import DataParallel

sys.path.append("/home/pritam.k/research/data-moe")


# from src.utils.helper import CustomModel, ModelArgs, MoeArgs
from src.utils.helper import ConfiguredMetric

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pwd

'/home/pritam.k/research/data-moe/notebooks'

In [3]:

seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)


def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=128, padding=True)

os.environ["HF_HOME"]="/home/pritam.k/research/huggingface"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
checkpoint = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


with open("../results/best_hyperparameters_tweet_eval.json", "r") as jfile:
    hp_params=json.load(jfile)

lr=hp_params['learning_rate']
batch_size=hp_params['per_device_train_batch_size']
num_epochs=hp_params['num_train_epochs']
num_experts=3
num_experts_per_tok=1
mode=f"{num_experts}_{num_experts_per_tok}_test" #ce/cbz/cb/cz/

In [4]:
# # Example tensor for difficulty levels
# # Difficulty mapping: 0 for easy, 1 for ambiguous, 2 for hard
# difficulty = torch.tensor([0 if instance in easy_ids else 1 if instance in ambi_ids else 2 for instance in hard_ids], device=device)

# # Pass the difficulty levels to the model during forward pass
# output, expert_tracking = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels, difficulty_levels=difficulty_levels)


In [5]:

ds_train = load_dataset('csv', data_files='../data/tweet_eval/updated/train.csv')
ds_val = load_dataset('csv', data_files='../data/tweet_eval/updated/val.csv')
ds_test = load_dataset('csv', data_files='../data/tweet_eval/updated/test.csv')

#print(ds)
#sys.exit()
data = DatasetDict({
    'train': ds_train['train'],
    'valid': ds_val['train'],
    'test': ds_test['train']}  
)

print(data)
#sys.exit()
tokenized_dataset = data.map(preprocess_function, batched=True, num_proc=12)
tokenized_dataset = tokenized_dataset.remove_columns(["text", "split"])
tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
tokenized_dataset = tokenized_dataset.rename_column("idx", "global_index")
#tokenized_dataset.set_format("torch")
tokenized_dataset.set_format("torch",columns=["global_index","input_ids", "attention_mask", "labels", "difficulty"])
print(tokenized_dataset["train"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


train_dataloader = DataLoader(
    tokenized_dataset["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_dataset["valid"], batch_size=batch_size, collate_fn=data_collator
)


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'split', 'idx', 'difficulty'],
        num_rows: 45615
    })
    valid: Dataset({
        features: ['text', 'label', 'split', 'idx', 'difficulty'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label', 'split', 'idx', 'difficulty'],
        num_rows: 12284
    })
})
Dataset({
    features: ['labels', 'global_index', 'difficulty', 'input_ids', 'attention_mask'],
    num_rows: 45615
})


In [6]:
def z_loss(logits):
    log_sum_exp=torch.logsumexp(logits,dim=1)
    sq_log_sum_exp=log_sum_exp=log_sum_exp**2
    loss=torch.mean(sq_log_sum_exp)
    return loss

def b_loss(logits,alpha=1e-2):
    T,N=logits.shape
    probs=torch.softmax(logits,dim=1)

    argmax_experts=torch.argmax(probs,dim=1)
    f=torch.zeros(N,device=logits.device)
    for i in range(N):
        f[i]=(argmax_experts==i).float().mean()
    P=probs.mean(dim=0)

    loss=alpha*N*torch.sum(f*P)
    return loss

In [7]:
@dataclasses.dataclass
class MoeArgs:
    def __init__(self, num_experts: int,num_experts_per_tok):
        self.num_experts=num_experts
        self.num_experts_per_tok=num_experts_per_tok

# class MoeLayer(nn.Module):
#     def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
#         super().__init__()
#         assert len(experts) > 0
#         self.gate = gate
#         self.experts = nn.ModuleList(experts)
#         self.args = moe_args
#         #self.expert_assignments = None

#     def forward(self, inputs: torch.Tensor):
#         gate_logits = self.gate(inputs)
#         loss_z=z_loss(gate_logits)
#         loss_b=b_loss(gate_logits)
#         weights,selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
#         weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
#         results = torch.zeros_like(inputs)
#         expert_tracking = {i: [] for i in range(self.args.num_experts)}
#         for current_expert_index, current_expert in enumerate(self.experts):
#             token_index, token_expert_index = torch.where(selected_experts == current_expert_index)
#             for idx in token_index.cpu().numpy():
#                 expert_tracking[current_expert_index].append(idx)
#             results[token_index] += weights[token_index, token_expert_index, None] * current_expert(
#                 inputs[token_index]
#             )
#         return results,loss_z,loss_b,expert_tracking


class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.gate = gate
        self.experts = nn.ModuleList(experts)
        self.args = moe_args

    def forward(self, inputs: torch.Tensor, expert_assignments=None):
        gate_logits = self.gate(inputs)
        loss_z = z_loss(gate_logits)
        loss_b = b_loss(gate_logits)

        # If expert_assignments is provided, override the gating mechanism
        if expert_assignments is not None:
            selected_experts = expert_assignments
            weights = torch.ones(inputs.size(0), self.args.num_experts_per_tok, dtype=inputs.dtype, device=inputs.device)
        else:
            weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
            weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
        
        results = torch.zeros_like(inputs)
        expert_tracking = {i: [] for i in range(self.args.num_experts)}
        if expert_assignments is not None:
            for current_expert_index, current_expert in enumerate(self.experts):
                token_index = torch.where(selected_experts == current_expert_index)[0]  # Get the indices of tokens assigned to the current expert
                for idx in token_index.cpu().numpy():
                    expert_tracking[current_expert_index].append(idx)

                # Correcting the shape by removing unsqueeze(1) and directly multiplying
                results[token_index] += weights[token_index].unsqueeze(-1) * current_expert(inputs[token_index])

        else:
            for current_expert_index, current_expert in enumerate(self.experts):
                token_index, token_expert_index = torch.where(selected_experts == current_expert_index)
                for idx in token_index.cpu().numpy():
                    expert_tracking[current_expert_index].append(idx)
                results[token_index] += weights[token_index, token_expert_index, None] * current_expert(
                    inputs[token_index]
                )
        
        return results, loss_z, loss_b, expert_tracking



class ModelArgs:
    def __init__(self, dim: int, hidden_dim: int, num_labels: int, moe: MoeArgs):
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.num_labels = num_labels
        self.moe = moe


class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


In [8]:
# class CustomModel(nn.Module):
#     def __init__(self,checkpoint,args: ModelArgs):
#         super(CustomModel,self).__init__()

#         self.model = model = AutoModel.from_pretrained(checkpoint,config=AutoConfig.from_pretrained(checkpoint, output_attentions=True,output_hidden_states=True))
#         for param in self.model.parameters():
#             param.requires_grad = False
#         self.args=args
#         self.moe_layer=MoeLayer(experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),moe_args=args.moe,)
#         self.classifier = nn.Linear(768,args.num_labels)

#     def forward(self, input_ids=None, attention_mask=None,labels=None):
#         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
#         moe_outputs,loss_z,loss_b,expert_tracking = self.moe_layer(outputs[0][:,0,:].view(-1,768))
#         #expert_assignments = self.moe_layer.expert_assignments
#         #print(experts)
#         logits=self.classifier(moe_outputs)

#         loss = None
#         if labels is not None:
#             loss_fct = nn.CrossEntropyLoss()
#             loss = loss_fct(logits.view(-1, self.args.num_labels), labels.view(-1))
#             loss=loss+loss_b+loss_z

#         return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states,attentions=outputs.attentions),expert_tracking



class CustomModel(nn.Module):
    def __init__(self, checkpoint, args: ModelArgs):
        super(CustomModel, self).__init__()

        self.model = AutoModel.from_pretrained(checkpoint, config=AutoConfig.from_pretrained(checkpoint, output_attentions=True, output_hidden_states=True))
        self.args = args
        self.moe_layer = MoeLayer(experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
                                  gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
                                  moe_args=args.moe)
        self.classifier = nn.Linear(768, args.num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None, difficulty=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # Map difficulty levels to expert indices
        if difficulty is not None:
            # Map difficulties: easy -> 0, ambiguous -> 1, hard -> 2
            expert_assignments = difficulty
        else:
            expert_assignments = None

        moe_outputs, loss_z, loss_b, expert_tracking = self.moe_layer(outputs[0][:, 0, :].view(-1, 768), expert_assignments=expert_assignments)
        logits = self.classifier(moe_outputs)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.args.num_labels), labels.view(-1))
            loss = loss + loss_b + loss_z

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions), expert_tracking


In [9]:
args=ModelArgs(dim=768,hidden_dim=3072,num_labels=3,moe=MoeArgs(num_experts,num_experts_per_tok))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=CustomModel(checkpoint=checkpoint,args=args).to(device)
# model.to(device)
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs!")
    model = DataParallel(model)
model.to(device)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total number of parameters: {total_params}')

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total number of parameters: 145883907


In [10]:
# df_train = pd.read_csv("../data/tweet_eval/updated/train.csv")


# df_train_easy = df_train[df_train['difficulty'] == 'easy']
# df_train_ambi = df_train[df_train['difficulty'] == 'ambi']
# df_train_hard = df_train[df_train['difficulty'] == 'hard']

# easy_ids=df_train_easy.idx.values.tolist()
# ambi_ids=df_train_ambi.idx.values.tolist()
# hard_ids=df_train_hard.idx.values.tolist()

In [11]:
# # Example tensor for difficulty levels
# # Difficulty mapping: 0 for easy, 1 for ambiguous, 2 for hard
# difficulty_levels = torch.tensor([0 if instance in easy_ids else 1 if instance in ambi_ids else 2 for instance in hard_ids], device=device)

# # Pass the difficulty levels to the model during forward pass
# output, expert_tracking = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels, difficulty_levels=difficulty_levels)


In [12]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

num_epochs = num_epochs
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

metric = evaluate.combine([
                evaluate.load('accuracy'), 
                ConfiguredMetric(evaluate.load('f1'), average='macro'),
                ConfiguredMetric(evaluate.load('precision'), average='macro'),
                ConfiguredMetric(evaluate.load('recall'), average='macro'),
            ])

1785


In [13]:

progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)))

train_losses = []  
val_losses = []   

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    num_batches = 0
    for batch in train_dataloader:
        e = 'global_index'
        batch = {k: v.to(device) for k, v in batch.items() if k != e}
        outputs, selected_experts = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()  # Accumulate the loss
        num_batches += 1
        
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar_train.update(1)

    
    avg_train_loss = total_loss / num_batches
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch+1}, Training loss: {avg_train_loss}")

    # Evaluation step
    model.eval()
    total_val_loss = 0
    num_val_batches = 0
    for batch in eval_dataloader:
        e = 'global_index'
        batch = {k: v.to(device) for k, v in batch.items() if k != e}
        with torch.no_grad():
            outputs, sel = model(**batch)
        
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
        
        # Calculate validation loss
        val_loss = outputs.loss.item()  # Assuming outputs has a loss attribute
        total_val_loss += val_loss
        num_val_batches += 1
        progress_bar_eval.update(1)

    # Calculate the average validation loss for this epoch and store it
    avg_val_loss = total_val_loss / num_val_batches
    val_losses.append(avg_val_loss)
    print(f"Epoch {epoch+1}, Validation loss: {avg_val_loss}")
    print(metric.compute())


  0%|          | 0/1785 [00:00<?, ?it/s]

RuntimeError: output with shape [86, 768] doesn't match the broadcast shape [86, 86, 768]

In [14]:

metric_test = evaluate.combine([
                evaluate.load('accuracy'), 
                ConfiguredMetric(evaluate.load('f1'), average='macro'),
                ConfiguredMetric(evaluate.load('precision'), average='macro'),
                ConfiguredMetric(evaluate.load('recall'), average='macro'),
            ])
test_dataloader = DataLoader(
    tokenized_dataset["test"], batch_size=batch_size, collate_fn=data_collator
)
expert_list=[]
model.eval()
for batch in test_dataloader:
    key_to_exclude = 'global_index'
    glo_batch = {k: v for k, v in batch.items() if k == key_to_exclude}
    batch = {k: v.to(device) for k, v in batch.items() if k != key_to_exclude}
    with torch.no_grad():
        outputs, s = model(**batch)
        for k in s.keys():
            for i in range(len(s[k])):
                index=s[k][i]
                s[k][i]=glo_batch['global_index'][index].item()
        expert_list.append(s)

    logits = outputs.logits
    #print(f"test: {s}")
    predictions = torch.argmax(logits, dim=-1)
    metric_test.add_batch(predictions=predictions, references=batch["labels"])


# with open("results/expert_dist.pkl", "wb") as pfile:
#     pickle.dump(expert_list, pfile)
results = metric_test.compute()
print(results)

{'accuracy': 0.5346792575708238, 'f1': 0.3988003065228942, 'precision': 0.3878664292991636, 'recall': 0.45945052288798466}


  _warn_prf(average, modifier, msg_start, len(result))


In [15]:
expert_list[0]

{0: [],
 1: [],
 2: [47615,
  47616,
  47617,
  47618,
  47619,
  47620,
  47621,
  47622,
  47623,
  47624,
  47625,
  47626,
  47627,
  47628,
  47629,
  47630,
  47631,
  47632,
  47633,
  47634,
  47635,
  47636,
  47637,
  47638,
  47639,
  47640,
  47641,
  47642,
  47643,
  47644,
  47645,
  47646,
  47647,
  47648,
  47649,
  47650,
  47651,
  47652,
  47653,
  47654,
  47655,
  47656,
  47657,
  47658,
  47659,
  47660,
  47661,
  47662,
  47663,
  47664,
  47665,
  47666,
  47667,
  47668,
  47669,
  47670,
  47671,
  47672,
  47673,
  47674,
  47675,
  47676,
  47677,
  47678,
  47679,
  47680,
  47681,
  47682,
  47683,
  47684,
  47685,
  47686,
  47687,
  47688,
  47689,
  47690,
  47691,
  47692,
  47693,
  47694,
  47695,
  47696,
  47697,
  47698,
  47699,
  47700,
  47701,
  47702,
  47703,
  47704,
  47705,
  47706,
  47707,
  47708,
  47709,
  47710,
  47711,
  47712,
  47713,
  47714,
  47715,
  47716,
  47717,
  47718,
  47719,
  47720,
  47721,
  47722,
  47723,
