In [1]:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import random
import torch.nn.functional as F

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

expert1_count = torch.zeros(32, dtype = int)
expert2_count = torch.zeros(32, dtype = int)

In [3]:
class GatingMechanism(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(GatingMechanism, self).__init__()
        self.gate = nn.Linear(input_dim, num_experts).to(device)

    def forward(self, x):
        x_mean = x.mean(dim=1)
        gate_scores = F.softmax(self.gate(x_mean), dim=-1)  # Shape: [batch_size, num_experts]
        return gate_scores.argmax(dim=-1)  # Shape: [batch_size]


In [4]:
import torch
import torch.nn as nn

class MoEModelWithPooling(nn.Module):
    def __init__(self, experts, input_dim):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        self.num_layers = len(experts[0].base_model.model.model.layers)  # Correct path to access layers
        self.gating = nn.ModuleList([GatingMechanism(input_dim, len(experts)) for _ in range(self.num_layers)])
        self.pooling = nn.AdaptiveAvgPool1d(1).to(device)  # Example pooling layer
        self.output_layer = nn.Linear(4096, 4).to(device)
        # self.softmax = nn.Softmax(dim=1).to(device)

    def forward(self, input_ids, attention_mask, training = True):
        global expert2_count
        global expert1_count

        x = self.experts[0].base_model.model.model.embed_tokens(input_ids)

        for i in range(self.num_layers):
            expert_indices = self.gating[i](x)
            
            if training == False:
                expert2_count[i] += int(expert_indices.sum())
                expert1_count[i] += int(expert_indices.shape[0] - expert_indices.sum())
                # print(expert1_count, expert2_count)

            layer_output = torch.zeros_like(x)

            for idx, expert in enumerate(self.experts):
                mask = (expert_indices == idx).unsqueeze(-1).unsqueeze(1).half()

                expert_input = x * mask

                expert_output = expert.base_model.model.model.layers[i](expert_input, attn_mask=attention_mask)[0]
                layer_output += expert_output * mask

            x = layer_output

        x = x.transpose(1, 2)  # Adjust dimensions for pooling
        x = self.pooling(x).squeeze(2)
        x = self.output_layer(x)

        return x

# GatingMechanism definition assumed to be implemented elsewhere


In [5]:
# # %%capture
# # Installs Unsloth, Xformers (Flash Attention) and all other packages!
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes

In [6]:
# Load pre-trained models
from unsloth import FastLanguageModel

max_seq_length = 256 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model1, tokenizer = FastLanguageModel.from_pretrained("unsloth_domain2",
                                                     max_seq_length=max_seq_length,
                                                     dtype=dtype,
                                                     load_in_4bit=load_in_4bit)

model2, tokenizer = FastLanguageModel.from_pretrained("ai2_arc_instruction_tuned_mistral_7b",
                                                     max_seq_length=max_seq_length,
                                                     dtype=dtype,
                                                     load_in_4bit=load_in_4bit)

config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

Unused kwargs: ['quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


==((====))==  Unsloth: Fast Mistral patching release 2024.5
   \\   /|    GPU: NVIDIA A40. Max memory: 44.352 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


model.safetensors:   0%|          | 0.00/4.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Unsloth 2024.5 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


config.json:   0%|          | 0.00/1.07k [00:00<?, ?B/s]

Unused kwargs: ['quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


==((====))==  Unsloth: Fast Mistral patching release 2024.5
   \\   /|    GPU: NVIDIA A40. Max memory: 44.352 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


model.safetensors:   0%|          | 0.00/4.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [7]:
# for param in model1.parameters():
#     param.requires_grad = False

# for param in model2.parameters():
#     param.requires_grad = False
    
models = [model1, model2]

In [8]:
num_layers = len(model1.base_model.model.model.layers)

moe_model = MoEModelWithPooling(models, input_dim=4096)

In [9]:
# print("Total params ", sum(p.numel() for p in model1.parameters()))
print("Total trainable params ", sum(p.numel() for p in moe_model.parameters() if p.requires_grad))      

# # a = 0
# # for param in model1.parameters():
# #     if param.requires_grad == True:
# #         a += sum(param.numel())
# # print(a)

Total trainable params  84164676


In [10]:
# for name, param in moe_model.named_parameters():
#     if param.requires_grad == True:
#         print(name)

In [11]:
print(moe_model)

MoEModelWithPooling(
  (experts): ModuleList(
    (0-1): 2 x PeftModelForCausalLM(
      (base_model): LoraModel(
        (model): MistralForCausalLM(
          (model): MistralModel(
            (embed_tokens): Embedding(32000, 4096)
            (layers): ModuleList(
              (0-31): 32 x MistralDecoderLayer(
                (self_attn): MistralSdpaAttention(
                  (q_proj): lora.Linear4bit(
                    (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Identity()
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=4096, out_features=16, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=16, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (

In [12]:
# from torch.cuda.amp import GradScaler

# scaler = GradScaler()

# temp = torch.ones((2,8), dtype=torch.int64).to(device)
# criterion = torch.nn.CrossEntropyLoss()
# labels = torch.rand(2, 4).float().to(device)

# optimizer = torch.optim.Adam(moe_model.parameters(), lr=1e-3)

# optimizer.zero_grad()
# with torch.cuda.amp.autocast():
#     output = moe_model(temp).float()
#     loss = criterion(output, labels.float())

# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()

In [13]:
tokenizer.pad_token = tokenizer.eos_token

In [14]:
from datasets import load_dataset, load_from_disk, concatenate_datasets, Dataset

dataset_location = 'medmcqa-prompts'

train_dataset = load_from_disk(f"{dataset_location}/train_prompts.hf")
# test_dataset = load_from_disk(f"{dataset_location}/test_prompts.hf")
eval_dataset = load_from_disk(f"{dataset_location}/eval_prompts.hf")

In [15]:
# train = []
# val = []
# count = 0
# for i in train_dataset:
#     train.append(i)
#     count += 1
#     if count >= 100:
#         break

# count = 0
# for i in eval_dataset:
#     val.append(i)
#     count += 1
#     if count >= 100:
#         break

# train_dataset = ''
# eval_dataset = ''

In [16]:
# print(train[0])

In [17]:
from torch.utils.data import DataLoader, Dataset

class MCQDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)  # Changed to float for one-hot encoding
        return item

    def __len__(self):
        return len(self.labels)

# Function to encode the data
def encode_data(tokenizer, prompts):
    encodings = tokenizer(prompts, truncation=True, padding=True, max_length = 128)
    return encodings

# Prepare the data for tokenization
prompts = [item['prompt'] for item in train_dataset]
labels = [item['label_one_hot'] for item in train_dataset]  # one-hot encoded labels

# Tokenize data
encodings = encode_data(tokenizer, prompts)

# Create dataset
train_set = MCQDataset(encodings, labels)

# DataLoader
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)


prompts = [item['prompt'] for item in eval_dataset]
labels = [item['label_one_hot'] for item in eval_dataset]  # one-hot encoded labels

# Tokenize data
encodings = encode_data(tokenizer, prompts)

# Create dataset
eval_set = MCQDataset(encodings, labels)

# DataLoader
val_loader = DataLoader(eval_set, batch_size=16, shuffle=True)

In [18]:
for i, batch in enumerate(train_loader):
    print("Batch", i)
    print("Input IDs:", batch['input_ids'].shape)
    print("Attention Mask:", batch['attention_mask'].shape)
    print("Labels:", batch['labels'].shape)

    # Print the actual content of the first example in the batch
    if i == 0:
        print("First input IDs example:", batch['input_ids'][0])
        print("First attention mask example:", batch['attention_mask'][0])
        print("First label example:", batch['labels'][0])

    # Optionally, break after a few batches to avoid too much output
    if i == 2:
        break

Batch 0
Input IDs: torch.Size([16, 128])
Attention Mask: torch.Size([16, 128])
Labels: torch.Size([16, 4])
First input IDs example: tensor([    1, 28705,    13,  2287, 22478, 28747,    13,  2287,   330, 28705,
        28740, 28782, 28733,  4395, 28733,   738,  4531,  7567,   395, 28705,
        28740,  1370,  3340,   302, 25352,   319,  7610, 28725,  1083,   514,
        28768, 18181,  3098,  8012,   286,   304,  3276, 14692,   294,   408,
         1029, 28723, 11606,   326,   697, 10924, 28747,   382, 28726, 28705,
        28784, 28723, 28781,   319, 28719, 28748, 28715, 28758, 28745,   320,
         9162, 28733, 28750, 28784, 28725, 28782, 28734, 28734, 28748,  3221,
        28770, 28745,   430,   362,   436,  5721,   727, 28733, 28750, 28734,
         5267,   395,   264,  2602,   302, 28705, 28740, 28770,  5267, 28745,
        10473,   306,   436, 28726,   410,  4081,   262,   727, 28733, 28782,
        28734,  5267, 28745,   304,   285,  2792,   262,  8371, 28705, 28740,
        28

In [None]:
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaler = GradScaler()

def print_memory_usage():
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

def compute_loss(outputs, labels, attention_mask):
    # Flatten the outputs and labels for loss calculation
    active_loss = attention_mask.view(-1) == 1  # Mask out padded tokens
    active_logits = outputs.view(-1, outputs.size(-1))[active_loss]
    active_labels = labels.view(-1)[active_loss]
    return F.cross_entropy(active_logits, active_labels)

def train_and_validate(model, train_loader, val_loader, log_file_path, epochs=15):
    scaler = GradScaler()
    device = torch.device("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, eps=1e-8)
    criterion = torch.nn.CrossEntropyLoss()
    
    best_val_accuracy = 0
    
    with open(log_file_path, 'a') as log_file:
        log_file.write("Starting training process...\n")
        log_file.flush()
        for epoch in range(epochs):
            total_train_loss = 0
            total_train_correct = 0
            train_samples = 0

            model.train()
            train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [TRAIN]", unit="batch")
            for i, batch in enumerate(train_pbar):
                input_ids, labels, attention_mask = batch['input_ids'].to(device), batch['labels'].to(device), batch['attention_mask'].to(dtype=bool, device=device)
                train_samples += labels.size(0)
#                 mask = attention_mask.unsqueeze(1)

#                 mask = mask.expand(-1, 128, -1)

                optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    # print(attention_mask.shape)
                    output = model(input_ids, attention_mask, True).float()
                    loss = criterion(output, labels.float())
                    predictions = torch.argmax(F.softmax(output,dim=1), dim=1)
                    labels_indices = torch.argmax(labels, dim=1)
                    total_train_correct += (predictions == labels_indices).sum().item()

                scaler.scale(loss).backward()
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                total_train_loss += loss.item()
                
                log_file.write(f"Batch {i}, Epoch {epoch+1}, Training Loss: {loss.item():.4f}, Training Accuracy: {100 * total_train_correct / train_samples:.2f}%\n")
                log_file.flush()
                
                train_pbar.set_postfix(loss=loss.item(), temp_acc=100 * total_train_correct / train_samples)

                if i % 1000 == 0:
                    print(i, loss.item())
                    print(f"Temp accuracy: ", total_train_correct / train_samples * 100)

            avg_train_loss = total_train_loss / len(train_loader)
            train_accuracy = total_train_correct / train_samples * 100
            print(f"Training Accuracy: ", train_accuracy)
            print(f"Epoch {epoch+1}, Loss: {avg_train_loss}")

            model.eval()
            total_val_loss, val_samples, total_val_correct = 0, 0, 0
            with torch.no_grad():
                for i, batch in enumerate(val_loader):
                    input_ids, labels, attention_mask = batch['input_ids'].to(device), batch['labels'].to(device), batch['attention_mask'].to(device)
                    with torch.cuda.amp.autocast():
                        output = model(input_ids, attention_mask, False).float()
                        val_loss = criterion(output, labels.float())
                        predictions = torch.argmax(F.softmax(output,dim=1), dim=1)
                        labels_indices = torch.argmax(labels, dim=1)
                        total_val_correct += (predictions == labels_indices).sum().item()

                    total_val_loss += val_loss.item()
                    val_samples += labels.size(0)
                    log_file.write(f"Batch {i}, Epoch {epoch+1}, Validation Loss: {loss.item():.4f}, Validation Accuracy: {100 * total_val_correct / val_samples:.2f}% Expert 1 - {expert1_count}, Expert 2 - {expert2_count}\n")
                    log_file.flush()
                    
            avg_val_loss = total_val_loss / len(val_loader)
            val_accuracy = total_val_correct / val_samples * 100
            print(f"Validation Accuracy: ", val_accuracy)
            print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f}") 
            
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                best_model_path = "best_model_full2.pth"
                torch.save(model.state_dict(), best_model_path)
                print(f"New best model saved with accuracy: {best_val_accuracy:.2f}% at {best_model_path}")

# Example usage
train_and_validate(moe_model, train_loader, val_loader, 'training_final_mega2.txt')


Epoch 1 [TRAIN]:   0%|          | 1/11427 [00:03<9:35:59,  3.02s/batch, loss=1.42, temp_acc=12.5]

0 1.4234427213668823
Temp accuracy:  12.5


Epoch 1 [TRAIN]:   9%|▉         | 1001/11427 [34:54<6:02:08,  2.08s/batch, loss=1.38, temp_acc=26.5]

1000 1.3806685209274292
Temp accuracy:  26.523476523476525


Epoch 1 [TRAIN]:  18%|█▊        | 2001/11427 [1:09:41<5:27:05,  2.08s/batch, loss=1.55, temp_acc=27.1]

2000 1.546600103378296
Temp accuracy:  27.09895052473763


Epoch 1 [TRAIN]:  26%|██▋       | 3001/11427 [1:44:26<4:51:34,  2.08s/batch, loss=3.66, temp_acc=27]  

3000 3.6552422046661377
Temp accuracy:  26.99933355548151


Epoch 1 [TRAIN]:  35%|███▌      | 4001/11427 [2:19:03<4:16:44,  2.07s/batch, loss=1.78, temp_acc=26.7]

4000 1.7825253009796143
Temp accuracy:  26.69645088727818


Epoch 1 [TRAIN]:  39%|███▉      | 4459/11427 [2:34:55<4:01:01,  2.08s/batch, loss=1.98, temp_acc=26.6]