In [None]:
%%writefile SHA_Diagonal/config.py

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

class CONFIG:
    """Hyper-parameters and shared objects for the SHA-Diagonal GLUE fine-tuning run.

    Values are read directly off the class (``config.CONFIG.task`` etc.).
    Note that ``tokenizer`` is loaded and ``device`` is resolved when this
    module is imported.
    """

    # Experiment bookkeeping.
    output_dir = "SHA-DIAG"
    task = "rte"
    seed = 42

    # Data / batching.
    max_len = 128
    train_batch = 32
    valid_batch = 32

    # Optimisation.
    epochs = 40
    learning_rate = 1e-2
    # Learning rate of the classifier head's own optimizer parameter group
    # (currently set to the same value as ``learning_rate``).
    classifier_learning_rate = 1e-2
    warmup_ratio = 0.06

    # Backbone model and shared runtime objects.
    model_name = "FacebookAI/roberta-large"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
%%writefile SHA_Diagonal/GLUE_data_setup.py

import config
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

class GLUEDataset(Dataset):
    """Tokenized GLUE (or stand-alone SST-2) split exposed as a map-style dataset.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors, plus ``token_type_ids`` when the tokenizer returns them. Texts are
    whitespace-normalised, then truncated/padded to ``max_len`` tokens. Labels
    are ``torch.float`` for the STS-B regression task and ``torch.long``
    otherwise.

    Args:
        dataset_name: GLUE task name. ``"mnli-mm"`` is accepted as an alias for
            ``"mnli"`` (combine it with ``split="validation_mismatched"``).
        split: split name looked up in the loaded ``DatasetDict``.
        tokenizer_name: Hugging Face tokenizer identifier.
        max_len: maximum tokenized sequence length.

    Raises:
        ValueError: if ``dataset_name`` is not one of ``SUPPORTED_TASKS``.
    """

    SUPPORTED_TASKS = ('sst2', 'cola', 'mrpc', 'qqp', 'stsb', 'rte', 'mnli', 'mnli-mm', 'qnli')

    def __init__(self, dataset_name=config.CONFIG.task, split="train", tokenizer_name=config.CONFIG.model_name, max_len=config.CONFIG.max_len):
        # Fail fast: an unknown name would otherwise only break later in __getitem__.
        if dataset_name not in self.SUPPORTED_TASKS:
            raise ValueError(
                f"Unsupported dataset_name {dataset_name!r}; expected one of {self.SUPPORTED_TASKS}"
            )
        self.dataset_name = dataset_name
        # "mnli-mm" (mismatched evaluation) uses the same "mnli" GLUE config.
        task = 'mnli' if dataset_name == 'mnli-mm' else dataset_name
        self._task = task

        if task == "sst2":
            self.dataset = load_dataset(task)[split].to_pandas()
        else:
            self.dataset = load_dataset('glue', task)[split].to_pandas()

        self.max_len = max_len
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Keep the per-task attribute names callers may already use, and record
        # which column arrays feed the tokenizer (second is None for
        # single-sentence tasks).
        if task in ('sst2', 'cola'):
            self.text = self.dataset['sentence'].values
            first, second = self.text, None
        elif task in ('mrpc', 'stsb', 'rte'):
            self.sentence1 = self.dataset['sentence1'].values
            self.sentence2 = self.dataset['sentence2'].values
            first, second = self.sentence1, self.sentence2
        elif task == 'qqp':
            # GLUE QQP names its text columns question1/question2.
            self.sentence1 = self.dataset['question1'].values
            self.sentence2 = self.dataset['question2'].values
            first, second = self.sentence1, self.sentence2
        elif task == 'mnli':
            self.premises = self.dataset['premise'].values
            self.hypotheses = self.dataset['hypothesis'].values
            first, second = self.premises, self.hypotheses
        else:  # 'qnli'
            self.questions = self.dataset['question'].values
            self.sentences = self.dataset['sentence'].values
            first, second = self.questions, self.sentences

        self.labels = self.dataset['label'].values
        self._first_texts = first
        self._second_texts = second

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Collapse runs of whitespace before tokenizing.
        first = ' '.join(self._first_texts[index].split())
        second = None
        if self._second_texts is not None:
            second = ' '.join(self._second_texts[index].split())

        inputs = self.tokenizer.encode_plus(
            first,
            second,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True
        )

        # STS-B is a regression task, so its target is a float score.
        label_dtype = torch.float if self._task == 'stsb' else torch.long
        result = {
            "input_ids": torch.tensor(inputs['input_ids'], dtype=torch.long),
            "attention_mask": torch.tensor(inputs['attention_mask'], dtype=torch.long),
            "labels": torch.tensor(self.labels[index], dtype=label_dtype),
        }

        if 'token_type_ids' in inputs:
            result["token_type_ids"] = torch.tensor(inputs['token_type_ids'], dtype=torch.long)

        return result


In [None]:
%%writefile SHA_Diagonal/peft_module.py

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class SHA_DIAGONAL(nn.Module):
    """Frozen linear layer plus a single-head mini-attention adapter with a diagonal gate.

    ``forward(x) = pretrained(x) + up_proj(Wo(attn(down_proj(x)))) * diagonal_linear_b``

    ``attn`` is single-head scaled dot-product self-attention over the sequence
    axis, computed with ``r // 4`` features per query/key/value. It uses no
    padding mask. ``diagonal_linear_b`` starts at zero, so a freshly built
    module returns exactly ``pretrained(x)`` until that gate is trained.

    Args:
        in_dim: input features of the wrapped linear layer.
        out_dim: output features of the wrapped linear layer.
        r: adapter bottleneck width; must be at least 4 so that q/k/v have
            ``r // 4 >= 1`` features.

    Raises:
        ValueError: if ``r < 4``.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        r: int = 24,
    ):
        super().__init__()
        if r < 4:
            # r // 4 == 0 would give empty q/k/v and a division by sqrt(0).
            raise ValueError(f"r must be >= 4, got {r}")
        self.r = r

        # Stand-in for the wrapped nn.Linear. The caller copies the real
        # weight/bias in; both stay frozen.
        self.pretrained = nn.Linear(in_dim, out_dim, bias=True)
        self.pretrained.weight.requires_grad = False
        self.pretrained.bias.requires_grad = False

        # Adapter projections use PyTorch's default nn.Linear initialization.
        self.down_proj = nn.Linear(in_dim, r, bias=False)
        self.Wqkv = nn.Linear(r, (r // 4) * 3, bias=False)
        self.Wo = nn.Linear(r // 4, r, bias=False)
        self.up_proj = nn.Linear(r, out_dim, bias=False)

        # Per-output-feature scale on the adapter branch. Zero init makes the
        # adapter a no-op at the start of training.
        self.diagonal_linear_b = nn.Parameter(torch.zeros(out_dim), requires_grad=True)

    def forward(self, x):
        pretrained_out = self.pretrained(x)

        # The 3-way unpack below requires x to be (batch, seq, in_dim).
        hidden = self.down_proj(x)
        B, S, C = hidden.shape
        head_dim = C // 4

        q, k, v = self.Wqkv(hidden).reshape(B, S, 3, head_dim).unbind(dim=2)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
        # NOTE(review): padded positions are attended to as well, since no mask is passed in.
        attended = scores.softmax(dim=-1) @ v

        adapter_out = self.up_proj(self.Wo(attended))
        return pretrained_out + adapter_out * self.diagonal_linear_b



def freeze_model(model):
    """Freeze every parameter of ``model`` except adapter attention, gate and classifier.

    A parameter is left untouched when its qualified name contains any of
    ``"Wqkv"``, ``"Wo"``, ``"diagonal_linear_b"`` or ``"classifier"``. Every
    other parameter gets ``requires_grad = False``; this includes adapter
    ``down_proj``/``up_proj`` weights and the copied ``pretrained`` layers.
    Kept parameters are not forced to ``True``.
    """
    trainable_markers = ("Wqkv", "Wo", "diagonal_linear_b", "classifier")
    for param_name, parameter in model.named_parameters():
        if any(marker in param_name for marker in trainable_markers):
            continue
        parameter.requires_grad = False


def create_peft(module):
    """Convert an ``nn.Linear`` into an ``SHA_DIAGONAL`` adapter initialized from it.

    The adapter's frozen ``pretrained`` layer receives copies of ``module``'s
    weight and bias. If ``module`` has no bias, a zero bias is used, which
    gives the same output. The adapter is placed on ``module``'s device and
    dtype, so it can replace a layer of a model that is already on the GPU.

    Returns:
        The new ``SHA_DIAGONAL`` module.
    """
    # nn.Linear stores its weight as (out_features, in_features).
    out_features, in_features = module.weight.shape
    peft = SHA_DIAGONAL(in_dim=in_features, out_dim=out_features)
    peft = peft.to(device=module.weight.device, dtype=module.weight.dtype)
    with torch.no_grad():
        peft.pretrained.weight.copy_(module.weight)
        if module.bias is not None:
            peft.pretrained.bias.copy_(module.bias)
        else:
            peft.pretrained.bias.zero_()
    return peft



def add_peft_layers(
    model,
    module_names: Tuple=("query", "value"),
    ignore_layers: Optional[List[int]]=None
):
    """Replace selected ``nn.Linear`` sub-modules of ``model`` with SHA-Diagonal adapters, in place.

    First, every ``nn.Dropout`` anywhere in ``model`` gets ``p = 0.0``; this
    includes the classifier head, not only frozen layers. Then the module tree
    is walked recursively. Each child that is an ``nn.Linear`` and whose
    attribute name is in ``module_names`` is replaced by ``create_peft(child)``.
    Any other child is descended into, unless its name equals ``str(i)`` for
    some ``i`` in ``ignore_layers``. Those names match the integer-indexed
    children of ``nn.ModuleList``/``nn.Sequential``.

    Args:
        model: root module; modified in place.
        module_names: attribute names of the linear layers to wrap.
        ignore_layers: child indices whose whole subtree is left unchanged.
            ``None`` (default) ignores nothing.
    """
    ignore_layers_str = {str(i) for i in (ignore_layers or ())}

    for submodule in model.modules():
        if isinstance(submodule, nn.Dropout):
            submodule.p = 0.0

    _replace_linear_modules(model, module_names, ignore_layers_str)


def _replace_linear_modules(parent, module_names, ignore_layers_str):
    """Recursively swap matching ``nn.Linear`` children of ``parent`` for SHA-Diagonal adapters."""
    # Snapshot the children, because setattr below rewrites parent's module registry.
    for name, child in list(parent.named_children()):
        if isinstance(child, nn.Linear) and name in module_names:
            setattr(parent, name, create_peft(child))
        elif name not in ignore_layers_str:
            _replace_linear_modules(child, module_names, ignore_layers_str)

In [None]:
%%writefile SHA_Diagonal/engine.py

import transformers
from transformers import AdamW
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, accuracy_score
import config
import peft_module


def eval_func(preds, labels):
    """Return the classification accuracy of ``preds`` against ``labels``.

    ``preds`` must have at least two dimensions, because the predicted class
    is the arg-max along axis 1. It is presumably a (num_examples, num_classes)
    logit array as returned by ``evaluate`` (to be confirmed for other callers).
    """
    predicted_classes = np.asarray(preds).argmax(axis=1).ravel()
    return accuracy_score(labels, predicted_classes)



def evaluate(model, val_dataloader):
    """Run ``model`` over ``val_dataloader`` without gradient tracking.

    Only ``input_ids``, ``attention_mask`` and ``labels`` are sent to the model.
    The model is left in eval mode.

    Returns:
        tuple: (mean per-batch loss, concatenated logits as a NumPy array,
        concatenated labels as a NumPy array).
    """
    device = config.CONFIG.device
    model.eval()

    summed_loss = 0
    logit_chunks = []
    label_chunks = []

    for batch in val_dataloader:
        inputs = {key: batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')}

        with torch.no_grad():
            outputs = model(**inputs)

        summed_loss += outputs["loss"].item()
        logit_chunks.append(outputs["logits"].detach().cpu().numpy())
        label_chunks.append(inputs['labels'].cpu().numpy())

    mean_loss = summed_loss / len(val_dataloader)
    all_logits = np.concatenate(logit_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0)
    return mean_loss, all_logits, all_labels




def train(model, optimizer, scheduler, train_dataloader, val_dataloader):
    """Fine-tune ``model`` for ``config.CONFIG.epochs`` epochs, validating after each one.

    For each batch, ``input_ids``/``attention_mask``/``labels`` are moved to
    ``config.CONFIG.device``. The returned ``"loss"`` is back-propagated, then
    both the optimizer and the scheduler are stepped (the scheduler once per
    batch). After every epoch, the mean training loss, validation loss and
    validation accuracy are written with ``tqdm.write``.
    """
    device = config.CONFIG.device
    epochs = config.CONFIG.epochs
    model.to(device)

    for epoch in tqdm(range(1, epochs + 1)):
        model.train()
        loss_train_total = 0

        progress_bar = tqdm(train_dataloader, desc='Epoch {:1d}'.format(epoch), leave=False, disable=True)

        for batch in progress_bar:
            optimizer.zero_grad()

            inputs = {'input_ids':      batch['input_ids'].to(device),
                      'attention_mask': batch['attention_mask'].to(device),
                      'labels':         batch['labels'].to(device),
                      }

            output = model(**inputs)

            loss = output["loss"]
            batch_loss = loss.item()
            loss_train_total += batch_loss
            loss.backward()

            optimizer.step()
            scheduler.step()

            # The model's loss is already averaged over the batch. len(batch)
            # would count the dict's keys, not the samples.
            progress_bar.set_postfix({'training_loss': '{:.3f}'.format(batch_loss)})

        tqdm.write(f'\nEpoch {epoch}')
        loss_train_avg = loss_train_total / len(train_dataloader)
        tqdm.write(f'Training loss: {loss_train_avg}')

        val_loss, predictions, true_vals = evaluate(model, val_dataloader)
        val_accuracy = eval_func(predictions, true_vals)
        tqdm.write(f'Validation loss: {val_loss}')
        tqdm.write(f'Accuracy : {val_accuracy}')


In [None]:
%%writefile SHA_Diagonal/train.py

import config, GLUE_data_setup, peft_module, engine
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

task = config.CONFIG.task

# MNLI has two validation splits; "mnli-mm" selects the mismatched one.
if task == "mnli-mm":
    validation_key = "validation_mismatched"
elif task == "mnli":
    validation_key = "validation_matched"
else:
    validation_key = "validation"

train_dataset = GLUE_data_setup.GLUEDataset(dataset_name=task, split="train")
validation_dataset = GLUE_data_setup.GLUEDataset(dataset_name=task, split=validation_key)

train_loader = DataLoader(
    train_dataset,
    batch_size=config.CONFIG.train_batch,
    num_workers=1,
    shuffle=True,
    pin_memory=True,
)
validation_loader = DataLoader(
    validation_dataset,
    batch_size=config.CONFIG.valid_batch,
    num_workers=1,
    shuffle=False,
    pin_memory=True,
)

# Seed before the model and the adapters are created, so their random init is reproducible.
torch.manual_seed(config.CONFIG.seed)
torch.cuda.manual_seed(config.CONFIG.seed)

if task.startswith("mnli"):
    num_labels = 3
elif task == "stsb":
    num_labels = 1
else:
    num_labels = 2

model = AutoModelForSequenceClassification.from_pretrained(
    config.CONFIG.model_name,
    num_labels=num_labels,
    output_attentions=False,
    output_hidden_states=False,
).to(config.CONFIG.device)

# Wrap the default ("query", "value") linears with SHA-Diagonal adapters, then
# freeze everything except the adapter attention/gate and the classifier.
peft_module.add_peft_layers(model=model)
peft_module.freeze_model(model)

# Split the trainable parameters into two optimizer groups so the classifier
# head can use its own learning rate.
classifier_parameters = []
other_parameters = []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if "classifier" in param_name:
        classifier_parameters.append(param)
    else:
        other_parameters.append(param)

optimizer = torch.optim.AdamW([
    {'params': other_parameters, 'lr': config.CONFIG.learning_rate},
    {'params': classifier_parameters, 'lr': config.CONFIG.classifier_learning_rate},
])

# Warm up for warmup_ratio of all optimizer steps, then decay linearly.
total_steps = len(train_loader) * config.CONFIG.epochs
warmup_steps = int(config.CONFIG.warmup_ratio * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

engine.train(model=model, optimizer=optimizer, scheduler=scheduler, train_dataloader=train_loader, val_dataloader=validation_loader)

In [3]:
%%writefile /home/azimi/SHA_Diagonal/utils.py
def num_parameters(model):
    """Print and return the number of trainable parameters outside the classifier head.

    A parameter is counted when ``requires_grad`` is set and its qualified name
    does not contain ``"classifier"``. Frozen classifier parameters are no
    longer subtracted from the trainable total.

    Returns:
        int: the count that was printed.
    """
    total_params = sum(
        weights.numel()
        for param_name, weights in model.named_parameters()
        if weights.requires_grad and 'classifier' not in param_name
    )
    print("total_params:", total_params)
    return total_params

Writing /home/azimi/SHA_Diagonal/utils.py
