In [1]:
from transformers import OpenAIGPTTokenizer
import torch
from torch import nn, optim
import torch.nn.functional as F
import math
import pandas as pd
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

In [2]:
class QQPDataset(Dataset):
    """Quora Question Pairs rows rendered as text for the GPT similarity model.

    The CSV must contain ``id``, ``qid1`` and ``qid2`` columns plus exactly three
    other columns, which are interpreted *positionally* as (question 1, question 2,
    label). Any row with a missing value is discarded.

    Each item is ``(order1, order2, label)``: the same pair rendered in both
    question orders as ``"<bos> A <sep> B <eos>"``.
    """

    def __init__(self, csv_path: str):
        raw = pd.read_csv(csv_path)
        trimmed = raw.drop(["id", "qid1", "qid2"], axis=1)
        # Positional rename: relies on the remaining column order in the file.
        trimmed.columns = ["q1", "q2", "label"]
        self.df = trimmed.dropna()

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        # iloc is positional, so the non-contiguous index left by dropna() is fine.
        first, second, label = self.df.iloc[idx]
        forward_text = f"<bos> {first} <sep> {second} <eos>"
        swapped_text = f"<bos> {second} <sep> {first} <eos>"
        return forward_text, swapped_text, label
    

class TokenizeCollate:
    """DataLoader ``collate_fn`` that tokenizes a batch of ``(order1, order2, label)`` items.

    ``tokenizer_obj`` is called once per ordering with ``return_tensors='pt'`` and
    ``padding=True``; presumably a Hugging Face tokenizer. It must return a mapping
    with ``"input_ids"`` and ``"attention_mask"`` tensors.

    Returns ``(ids1, ids2, mask1, mask2, labels)`` with boolean masks and float32 labels.
    """

    def __init__(self, tokenizer_obj):
        self.tokenizer = tokenizer_obj

    def __call__(self, x):
        # Unpacking each item enforces the (text, text, label) triple shape.
        triples = [(a, b, c) for a, b, c in x]
        firsts = [t[0] for t in triples]
        seconds = [t[1] for t in triples]
        targets = [t[2] for t in triples]

        enc_first = self.tokenizer(firsts, return_tensors='pt', padding=True)
        enc_second = self.tokenizer(seconds, return_tensors='pt', padding=True)
        label_tensor = torch.tensor(targets).to(torch.float32)

        return (
            enc_first["input_ids"],
            enc_second["input_ids"],
            enc_first["attention_mask"].to(torch.bool),
            enc_second["attention_mask"].to(torch.bool),
            label_tensor,
        )

In [3]:
def prepare_mask(padding_mask, causal=True):
    """Turn a per-token padding mask into a boolean attention mask.

    Args:
        padding_mask: ``[batch_size, seq_len]`` tensor, truthy (True / 1) for real
            tokens. Integer 0/1 masks are accepted and converted to bool.
        causal: if True, additionally forbid each query from attending to later
            positions (lower-triangular mask).

    Returns:
        Bool tensor, True where attention is allowed. Its shape is
        ``[batch_size, 1, 1, seq_len]`` when ``causal`` is False, and
        ``[batch_size, 1, seq_len, seq_len]`` when ``causal`` is True. It
        broadcasts over the head dimension of the attention scores.
    """
    # bool() so the mask works with `~mask` / masked_fill even for 0/1 int inputs.
    mask = padding_mask.bool().unsqueeze(1).unsqueeze(-2)
    if causal:
        seq_len = mask.shape[-1]
        # Build directly on the mask's device instead of CPU + transfer each call.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=mask.device)).bool()
        mask = mask & causal_mask
    return mask

In [4]:
class Attention(nn.Module):
    """Multi-head scaled dot-product attention with a fused QKV projection.

    Args:
        input_dim: feature size of the incoming hidden states.
        output_dim: total attention width, split evenly across ``num_heads``.
        num_heads: number of attention heads; must divide ``output_dim``.
        p: dropout probability applied to the attention weights and to the
            output projection.

    Raises:
        ValueError: if ``output_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, output_dim, num_heads, p):
        super(Attention, self).__init__()
        if output_dim % num_heads != 0:
            raise ValueError(
                f"output_dim ({output_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.c_attn = nn.Linear(input_dim, output_dim * 3)
        self.c_proj = nn.Linear(output_dim, output_dim)
        self.attn_dropout = nn.Dropout(p=p, inplace=False)
        self.resid_dropout = nn.Dropout(p=p, inplace=False)

    def scaled_dot_product_attention(self, q, k, v, mask):
        """
        q: [batch_size, num_heads, seq1_len, head_dim]
        k: [batch_size, num_heads, seq2_len, head_dim]
        v: [batch_size, num_heads, seq2_len, head_dim]
        mask: bool, broadcastable to [batch_size, num_heads, seq1_len, seq2_len];
              True = may attend. None disables masking.
        (seq1_len = seq2_len for self attention)
        returns: [batch_size, num_heads, seq1_len, head_dim]
        """
        qk = q.matmul(k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
        if mask is not None:
            # A query row with every key masked would softmax to NaN.
            qk = qk.masked_fill(~mask, -torch.inf)
        attn_weights = self.attn_dropout(qk.softmax(dim=-1))
        return attn_weights.matmul(v)

    def qkv_reshape(self, x):
        """[batch, seq, dim] -> [batch, num_heads, seq, dim // num_heads]."""
        return x.view(x.shape[0], x.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)

    def output_reshape(self, x):
        """[batch, num_heads, seq, head_dim] -> [batch, seq, num_heads * head_dim]."""
        x = x.permute(0, 2, 1, 3)
        return x.reshape(x.shape[0], x.shape[1], -1)

    def forward(self, x, mask):
        """x: [batch, seq, input_dim] -> [batch, seq, output_dim]."""
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        q, k, v = self.qkv_reshape(q), self.qkv_reshape(k), self.qkv_reshape(v)
        attn_outputs = self.output_reshape(self.scaled_dot_product_attention(q, k, v, mask))
        return self.resid_dropout(self.c_proj(attn_outputs))
    

class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    The hidden width is ``4 * input_dim``. Dropout is applied to the projected
    output rather than before the activation. Applying it before GELU scales
    activations by ``1 / (1 - p)`` ahead of a nonlinearity, which makes the
    train-time expectation differ from eval-time behaviour.

    Args:
        input_dim: model width (input and output feature size).
        p: dropout probability on the sub-layer output.
    """

    def __init__(self, input_dim, p) -> None:
        super(MLP, self).__init__()
        self.c_fc = nn.Linear(input_dim, input_dim * 4)
        self.c_proj = nn.Linear(input_dim * 4, input_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(p=p, inplace=False)

    def forward(self, x):
        hidden = self.act(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))


class Block(nn.Module):
    """Post-LayerNorm transformer layer.

    Computes ``h = LN1(x + Attn(x))`` followed by ``out = LN2(h + MLP(h))``.

    Args:
        d_model: model width.
        num_heads: attention heads.
        p: dropout probability forwarded to the attention and MLP sub-layers.
    """

    def __init__(self, d_model, num_heads, p):
        super(Block, self).__init__()
        # Registration order kept: it fixes state_dict keys and init RNG order.
        self.attn = Attention(d_model, d_model, num_heads, p)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, p)
        self.ln_2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        attended = self.ln_1(x + self.attn(x, mask=mask))
        return self.ln_2(attended + self.mlp(attended))
    

class GPT(nn.Module):
    """Decoder stack: token embeddings + learned position embeddings, then ``n_layers`` Blocks.

    Args:
        vocab_size: rows in the token embedding table.
        max_seq_len: number of learned positions; longer inputs are rejected.
        n_layers: number of ``Block`` layers.
        d_model: hidden size.
        num_heads: attention heads per layer.
        p: dropout probability used for the embedding dropout and inside each Block.
    """

    def __init__(self, vocab_size, max_seq_len, n_layers, d_model, num_heads, p):
        super(GPT, self).__init__()
        self.d_model, self.max_seq_len = d_model, max_seq_len
        self.tokens_embed = nn.Embedding(vocab_size, d_model)
        self.positions_embed = nn.Embedding(max_seq_len, d_model)
        self.drop = nn.Dropout(p=p, inplace=False)
        self.h = nn.ModuleList([Block(d_model, num_heads, p) for _ in range(n_layers)])

    def forward(self, x, mask=None):
        """
        x: token ids, [batch_size, seq_len]
        mask: optional bool attention mask passed unchanged to every Block
              (e.g. from prepare_mask), broadcastable to
              [batch_size, num_heads, seq_len, seq_len]
        returns: hidden states, [batch_size, seq_len, d_model]

        Raises ValueError if seq_len exceeds max_seq_len; otherwise the position
        lookup would fail with an opaque index error or device-side assert.
        """
        seq_len = x.shape[-1]
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )
        # NOTE(review): token embeddings are scaled by sqrt(d_model). Confirm this
        # matches how weights.pth was produced, since GPT-1 style checkpoints may not use it.
        hidden = self.tokens_embed(x) * math.sqrt(self.d_model)
        # [1, seq_len] positions broadcast over the batch; created on x's device.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        hidden = self.drop(hidden + self.positions_embed(positions))
        for layer in self.h:
            hidden = layer(hidden, mask=mask)
        return hidden
    

class GPTSemanticSimilarity(nn.Module):
    """Pair classifier on top of a shared ``GPT`` encoder.

    Both question orderings are encoded with a causal, padding-aware mask.
    Hidden states at padding positions are zeroed. Each sequence is then
    zero-padded to ``max_seq_len`` and the two are summed position-wise. A
    single linear layer over the flattened ``[max_seq_len * d_model]`` vector
    produces one logit per pair.

    Args:
        gpt_base: the (possibly frozen) GPT encoder.
        p: dropout probability on the summed representation (default 0.1).
    """

    def __init__(self, gpt_base: GPT, p: float = 0.1):
        super(GPTSemanticSimilarity, self).__init__()
        self.gpt_base = gpt_base
        self.dropout = nn.Dropout(p=p, inplace=False)
        self.classifier = nn.Linear(self.gpt_base.d_model * self.gpt_base.max_seq_len, 1)

    def _pad_to_max_len(self, hidden):
        """Right-pad ``[batch, seq, d_model]`` with zeros along seq up to ``max_seq_len``."""
        extra = self.gpt_base.max_seq_len - hidden.shape[1]
        if extra < 0:
            raise ValueError(
                f"sequence length {hidden.shape[1]} exceeds max_seq_len {self.gpt_base.max_seq_len}"
            )
        return F.pad(hidden, (0, 0, 0, extra))

    def forward(self, x1, x2, x1_mask, x2_mask):
        """
        x1, x2: token ids [batch, seq1] / [batch, seq2] (the two question orderings)
        x1_mask, x2_mask: padding masks matching x1 / x2, truthy for real tokens
        returns: logits, [batch]
        """
        h1 = self.gpt_base(x1, prepare_mask(x1_mask))
        h2 = self.gpt_base(x2, prepare_mask(x2_mask))
        # Pad positions still get hidden states. Zero them so a pair's logit does
        # not depend on how much padding its batch happened to need.
        h1 = h1 * x1_mask.unsqueeze(-1).to(h1.dtype)
        h2 = h2 * x2_mask.unsqueeze(-1).to(h2.dtype)
        x = self.dropout(self._pad_to_max_len(h1) + self._pad_to_max_len(h2))
        return self.classifier(x.flatten(start_dim=1)).view(-1)

In [5]:
def init_finetuning_model_and_tokenizer(weights_path, device, freeze_base_weights=False):
    """Load pretrained GPT weights, add task special tokens, and wrap for pair classification.

    Args:
        weights_path: path to a ``GPT`` state dict built with the hyper-parameters
            below (tokenizer base vocab, max_seq_len=512, d_model=768, 12 heads,
            12 layers).
        device: torch device for the model.
        freeze_base_weights: if True, every pretrained GPT parameter is frozen.
            The embedding rows added for the new special tokens stay trainable,
            as do the classifier head parameters.

    Returns:
        ``(tokenizer, model)``. The tokenizer knows ``<bos>/<eos>/<sep>/<pad>``;
        the model is a ``GPTSemanticSimilarity`` on ``device``.
    """
    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    gpt = GPT(
        vocab_size=tokenizer.vocab_size,
        max_seq_len=512,
        d_model=768,
        num_heads=12,
        n_layers=12,
        p=0.1
    ).to(device)
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files
    # (or pass weights_only=True on torch versions that support it).
    gpt.load_state_dict(torch.load(weights_path, map_location=device))

    base_vocab_size = gpt.tokens_embed.num_embeddings
    special_tokens_dict = {"bos_token":"<bos>", "eos_token":"<eos>", "sep_token":"<sep>", "pad_token":"<pad>"}
    num_added = tokenizer.add_special_tokens(special_tokens_dict)
    if num_added > 0:
        base_weights = gpt.tokens_embed.weight.data
        # N(0, 1) rows can be far larger than pretrained rows. Match their spread
        # so the new tokens start on the same scale.
        new_rows = torch.randn(num_added, gpt.d_model) * base_weights.std().item()
        new_rows = new_rows.to(device=base_weights.device, dtype=base_weights.dtype)
        gpt.tokens_embed.weight.data = torch.cat([base_weights, new_rows], dim=0)
        # Keep Embedding metadata consistent with the enlarged table.
        gpt.tokens_embed.num_embeddings = base_vocab_size + num_added

    if freeze_base_weights:
        for param in gpt.parameters():
            param.requires_grad = False
        if num_added > 0:
            # Freezing the whole embedding would also freeze the randomly
            # initialised special-token rows, so they could never be learned.
            # Keep the matrix trainable, but zero the gradient on pretrained rows.
            # Note: optimizers with weight decay would still alter those rows.
            embed_weight = gpt.tokens_embed.weight
            embed_weight.requires_grad = True
            grad_mask = torch.zeros(
                embed_weight.shape[0], 1, device=embed_weight.device, dtype=embed_weight.dtype
            )
            grad_mask[base_vocab_size:] = 1.0
            embed_weight.register_hook(lambda grad: grad * grad_mask)

    gpt = GPTSemanticSimilarity(gpt).to(device)
    return tokenizer, gpt

In [6]:
# Use the best available accelerator rather than hard-failing off Apple silicon.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")

BATCH_SIZE = 32
LR = 6.25e-5
EPOCHS = 5
SEED = 42

WEIGHTS_PATH = "weights.pth"
DATASET_PATH = "dataset/train.csv"

# Seed torch's RNGs (new-token embedding init, weight init, DataLoader shuffling)
# so runs are reproducible.
torch.manual_seed(SEED)

In [7]:
tokenizer, model = init_finetuning_model_and_tokenizer(WEIGHTS_PATH, DEV, freeze_base_weights=True)
dataset = QQPDataset(DATASET_PATH)
# Use the configured batch size rather than a duplicated literal.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=TokenizeCollate(tokenizer))

In [8]:
# The model emits one unnormalised logit per pair, so use the logits form of BCE.
crit = nn.BCEWithLogitsLoss(reduction="mean")
opt = optim.Adam(params=model.parameters(), lr=LR)

In [9]:
for e in range(1, EPOCHS + 1):
    # Enter training mode explicitly every epoch so dropout is active even if an
    # earlier cell left the model in eval() mode (e.g. after inspection/validation).
    model.train()
    loop = tqdm(enumerate(loader), total=len(loader), leave=True, position=0)
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")
    total_loss = 0.0
    for i, (x1, x2, x1_mask, x2_mask, labels) in loop:
        x1, x2 = x1.to(DEV), x2.to(DEV)
        x1_mask, x2_mask = x1_mask.to(DEV), x2_mask.to(DEV)
        labels = labels.to(DEV)

        opt.zero_grad()
        yhat = model(x1, x2, x1_mask, x2_mask)
        loss = crit(yhat, labels)
        loss.backward()
        opt.step()

        # Running mean of per-batch loss over the current epoch.
        total_loss += loss.item()
        loop.set_postfix(loss = total_loss / (i + 1))

Epoch : [1/5]:   0%|          | 0/12634 [00:00<?, ?it/s]

Epoch : [1/5]:   0%|          | 17/12634 [00:08<1:50:50,  1.90it/s, loss=0.781]


KeyboardInterrupt: 