In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from transformers import BertTokenizer, RobertaTokenizer, XLNetTokenizer, AdamW
from transformers import BertForSequenceClassification, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer, XLNetForSequenceClassification, XLNetTokenizer, BertConfig, RobertaConfig, XLNetConfig, AdamW
from datasets import load_dataset
import random



In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, hidden_dim, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(512, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(embed_dim, 2)

    def forward(self, input_ids, attention_mask):
        positions = torch.arange(0, input_ids.size(1)).unsqueeze(0).repeat(input_ids.size(0), 1).to(input_ids.device)
        x = self.embedding(input_ids) + self.position_embedding(positions)
        x = self.dropout(self.transformer_encoder(x.transpose(0, 1), src_key_padding_mask=~attention_mask.bool()).transpose(0, 1))
        logits = self.classifier(x[:, 0, :])
        return logits

embed_dim = 256
num_heads = 8
num_layers = 4
hidden_dim = 512
dropout = 0.1
learning_rate = 2e-5
weight_decay = 0.01
batch_size = 16
epochs = 6

def preprocess_data(data):
    texts = data['text']
    labels = data['label']
    return texts, labels

class SentimentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer.encode_plus(
            self.texts[idx],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        original_text = self.texts[idx]
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': label,
            'original_text': original_text
        }
    

def combined_loss(logits, soft_logits_list, labels, alpha=0.5):
    kl_loss = 0
    
    for soft_logits in soft_logits_list:
        kl_loss += nn.functional.kl_div(
            torch.log_softmax(logits, dim=-1), 
            torch.softmax(soft_logits, dim=-1), 
                reduction='batchmean'
            )
    
    kl_loss /= (len(soft_logits_list))

    ce_loss = nn.functional.cross_entropy(logits, labels)

    return alpha * kl_loss + (1-alpha) * ce_loss
    
def evaluate_model(model, dataloader):
    model.eval()
    total_loss = 0
    correct_predictions = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            total_loss += loss.item()
            preds = logits.argmax(dim=-1)
            correct_predictions += (preds == labels).sum().item()
    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / len(dataloader.dataset)
    return avg_loss, accuracy

def load_model(model_class, tokenizer_class, model_path, config_class, pretrained_model_name):
    config = config_class.from_pretrained(model_path)
    model = model_class.from_pretrained(model_path, config=config)
    tokenizer = tokenizer_class.from_pretrained(pretrained_model_name)
    return model, tokenizer

def train_student(student, dataloader, teacher_model_tokenizer, valid_loader,model_name, epochs=6, alpha=0.5):
    student.train()
    student.to(device)
    optimizer = AdamW(student.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    for epoch in range(epochs):
        total_loss = 0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            batch_size = input_ids.size(0)

            optimizer.zero_grad()
            logits = student(input_ids, attention_mask=attention_mask)

            # Get the soft logits from the teacher models dynamically
            soft_logits_lists = []
            for teacher in teacher_model_tokenizer:
                teacher_model = teacher["model"]
                tokenizer = teacher["tokenizer"]
                teacher_model.eval()
                teacher_model.to(device)

                with torch.no_grad():
                    # Tokenize the entire batch using the teacher's tokenizer
                    encoded_batch = tokenizer(batch['original_text'], padding=True, truncation=True, max_length=512, return_tensors='pt').to(device)
                    teacher_input_ids = encoded_batch['input_ids']
                    teacher_attention_mask = encoded_batch['attention_mask']
                    soft_logits = teacher_model(teacher_input_ids, attention_mask=teacher_attention_mask).logits
                    soft_logits_lists.append(soft_logits)

            # Calculate the mean of the soft logits from all teacher models
            mean_soft_logits = torch.mean(torch.stack(soft_logits_lists), dim=0)
            
            # Compute loss with the most common labels
            most_common_labels = []
            for i in range(batch_size):
                predictions = [torch.argmax(mean_soft_logits[i]).item()]
                most_common_label = max(set(predictions), key=predictions.count)
                most_common_labels.append(most_common_label)
            most_common_labels = torch.tensor(most_common_labels).to(device)

            total_batch_loss = combined_loss(logits, mean_soft_logits, most_common_labels, alpha)
            total_batch_loss.backward()
            optimizer.step()
            total_loss += total_batch_loss.item()

        print(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader)}')
        avg_loss, accuracy = evaluate_model(student, valid_loader)
        print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')
    # Save the model after training
    torch.save(student.state_dict(), f"{model_name}.pt")
    print(f"Model saved as {model_name}.pt")


In [3]:
# Set the seed for reproducibility
seed = 42
torch.manual_seed(seed)
random.seed(seed)



# Load the IMDB dataset
dataset = load_dataset('imdb')
train_texts, train_labels = preprocess_data(dataset['train'])
test_texts, test_labels = preprocess_data(dataset['test'])

# Define paths
bert_model_path = './bert_model/'
roberta_model_path = './roberta_model'
xlnet_model_path = './xlnet_model'

# Load models and tokenizers
bert_model, bert_tokenizer = load_model(BertForSequenceClassification, BertTokenizer, bert_model_path, BertConfig, "bert-base-uncased")
roberta_model, roberta_tokenizer = load_model(RobertaForSequenceClassification, RobertaTokenizer, roberta_model_path, RobertaConfig, "roberta-base")
xlnet_model, xlnet_tokenizer = load_model(XLNetForSequenceClassification, XLNetTokenizer, xlnet_model_path, XLNetConfig, "xlnet-base-cased")

# Create Datasets
train_dataset_bert = SentimentDataset(train_texts, train_labels, bert_tokenizer)
train_dataset_roberta = SentimentDataset(train_texts, train_labels, roberta_tokenizer)
train_dataset_xlnet = SentimentDataset(train_texts, train_labels, xlnet_tokenizer)


test_dataset_bert = SentimentDataset(test_texts, test_labels, bert_tokenizer)
test_dataset_roberta = SentimentDataset(test_texts, test_labels, roberta_tokenizer)
test_dataset_xlnet = SentimentDataset(test_texts, test_labels, xlnet_tokenizer)

# Create a single shuffled order of indices
indices = list(range(len(train_dataset_bert)))
random.shuffle(indices)
sampler = SubsetRandomSampler(indices)

# Create Dataloaders
train_loader_bert = DataLoader(train_dataset_bert, batch_size=batch_size, sampler=sampler)
train_loader_roberta = DataLoader(train_dataset_roberta, batch_size=batch_size, sampler=sampler)
train_loader_xlnet = DataLoader(train_dataset_xlnet, batch_size=batch_size, sampler=sampler)

test_loader_bert = DataLoader(test_dataset_bert, batch_size=batch_size, shuffle=True)
test_loader_roberta = DataLoader(test_dataset_roberta, batch_size=batch_size, shuffle=True)
test_loader_xlnet = DataLoader(test_dataset_xlnet, batch_size=batch_size, shuffle=True)

# Move models to device
roberta_model.to(device)
xlnet_model.to(device)
bert_model.to(device)


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [4]:
# Define student models
combined_student_model = TransformerModel(vocab_size=len(bert_tokenizer), embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, dropout=dropout)
student_model_bert = TransformerModel(vocab_size=len(bert_tokenizer), embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, dropout=dropout)
student_model_roberta = TransformerModel(vocab_size=len(roberta_tokenizer), embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, dropout=dropout)
student_model_xlnet = TransformerModel(vocab_size=len(xlnet_tokenizer), embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, dropout=dropout)

# Move student models to device
combined_student_model.to(device)
student_model_bert.to(device)
student_model_roberta.to(device)
student_model_xlnet.to(device)


student_dataloader = DataLoader(train_dataset_bert, batch_size=batch_size, sampler=sampler)
test_dataset_student = SentimentDataset(test_texts, test_labels, bert_tokenizer)
test_loader_student = DataLoader(test_dataset_student, batch_size=batch_size, shuffle=False)

In [5]:
teacher_model_tokeniser = [ 
    {"model":bert_model ,"tokenizer" : bert_tokenizer},
    {"model":roberta_model, "tokenizer" : roberta_tokenizer},
    {"model":xlnet_model ,"tokenizer" : xlnet_tokenizer}
    ]


print(f'combined student Model:')
train_student(combined_student_model, student_dataloader, teacher_model_tokeniser,  test_loader_student,"combined_student_Model" , epochs=epochs)


print(f'bert student Model:')
train_student(student_model_bert, student_dataloader, [{"model": bert_model, "tokenizer": bert_tokenizer}], test_loader_student, "bert_student_Model", epochs=epochs)

print(f'roberta student Model:')
train_student(student_model_roberta, student_dataloader, [{"model": roberta_model, "tokenizer": roberta_tokenizer}], test_loader_student,"roberta_student_Model", epochs=epochs)

print(f'xlnet student Model:')
train_student(student_model_xlnet, student_dataloader, [{"model": xlnet_model, "tokenizer": xlnet_tokenizer}], test_loader_student, "xlnet_student_Model",  epochs=epochs)



combined student Model:




Epoch 1/6, Loss: 0.6803152970755169
Validation Loss: 0.6070, Accuracy: 0.7078
Epoch 2/6, Loss: 0.6427579588487372
Validation Loss: 0.5586, Accuracy: 0.7486
Epoch 3/6, Loss: 0.6310390128360218
Validation Loss: 0.5392, Accuracy: 0.7712
Epoch 4/6, Loss: 0.6196214401485519
Validation Loss: 0.5085, Accuracy: 0.7855
Epoch 5/6, Loss: 0.612385183763443
Validation Loss: 0.4951, Accuracy: 0.8043
Epoch 6/6, Loss: 0.603880204608329
Validation Loss: 0.4981, Accuracy: 0.8061
Model saved as combined_student_Model.pt
bert student Model:




Epoch 1/6, Loss: 0.6823795293663376
Validation Loss: 0.6087, Accuracy: 0.7097
Epoch 2/6, Loss: 0.6457242638692593
Validation Loss: 0.5526, Accuracy: 0.7565
Epoch 3/6, Loss: 0.628658787092946
Validation Loss: 0.5292, Accuracy: 0.7803
Epoch 4/6, Loss: 0.6184460613030466
Validation Loss: 0.5044, Accuracy: 0.7947
Epoch 5/6, Loss: 0.6101442033750311
Validation Loss: 0.4872, Accuracy: 0.8147
Epoch 6/6, Loss: 0.6026526764654915
Validation Loss: 0.4942, Accuracy: 0.8007
Model saved as bert_student_Model.pt
roberta student Model:
Epoch 1/6, Loss: 0.6867016060796214
Validation Loss: 0.6417, Accuracy: 0.6198
Epoch 2/6, Loss: 0.6511006815572313
Validation Loss: 0.5620, Accuracy: 0.7356
Epoch 3/6, Loss: 0.6343808185573732
Validation Loss: 0.5322, Accuracy: 0.7728
Epoch 4/6, Loss: 0.6228217849613991
Validation Loss: 0.5174, Accuracy: 0.7926
Epoch 5/6, Loss: 0.6146679427176809
Validation Loss: 0.4958, Accuracy: 0.8073
Epoch 6/6, Loss: 0.608493223605214
Validation Loss: 0.4887, Accuracy: 0.8093
Model 