In [1]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from numpy import sqrt, log
from dataset import TokenTypesDataset

In [3]:
train_dataset = TokenTypesDataset(folder="../tokentype_keywordasitis_data/train")
val_dataset = TokenTypesDataset(folder="../tokentype_keywordasitis_data/validation", train=False, vocabs=(train_dataset.token2idx, train_dataset.idx2token), max_length=train_dataset.max_length)
test_dataset = TokenTypesDataset(folder="../tokentype_keywordasitis_data/test", train=False, vocabs=(train_dataset.token2idx, train_dataset.idx2token), max_length=train_dataset.max_length)
assert val_dataset.vocab_size == train_dataset.vocab_size == test_dataset.vocab_size
assert val_dataset.max_length == train_dataset.max_length == test_dataset.max_length


In [4]:
def generate_square_subsequent_mask(sz, device):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_masks(src, pad_idx, device):
    src_seq_len = src.shape[1]

    src_mask = generate_square_subsequent_mask(src_seq_len, device=device)

    src_padding_mask = (src == pad_idx)
    return src_mask, src_padding_mask

In [5]:
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import Phi3Config, Phi3ForCausalLM

def calc_accuracy_and_hitrate_at_k(model: Phi3ForCausalLM, loader, device, k=3):
    
    total_seq_len_in_dataloader = 0
    total_correct_predictions = 0
    total_hits_in_loader = 0
    
    for batch in tqdm(loader):
        src_mask, src_pad_mask = (create_masks(batch, pad_idx=loader.dataset.pad_id, device=device))
        src_mask = src_mask.to(device)
        src_pad_mask = src_pad_mask.to(device)
        batch = batch.to(device)
    
        labels = batch[:, 1:]
        labels_pad_mask = (labels == 0)
        labels_without_pad = labels[~labels_pad_mask]
        logits_without_last = model.forward(input_ids=batch, attention_mask=src_pad_mask).logits[:, :-1, :]
        predictions = logits_without_last.argmax(dim=-1)
        predictions_without_pad = predictions[~labels_pad_mask]

        top_k_predictions = torch.argsort(logits_without_last, dim=-1, descending=True)[:, :, :k][~labels_pad_mask]

    
        total_seq_len_in_batch = labels_without_pad.shape[0]
        total_predicions_len = predictions_without_pad.shape[0]
    
        assert total_seq_len_in_batch == total_predicions_len
    
        total_seq_len_in_dataloader += total_seq_len_in_batch
    
        correct_predictions = (labels_without_pad == predictions_without_pad).float().sum()

        total_hits_in_batch = (top_k_predictions == labels_without_pad.unsqueeze(1)).any(dim=1).float().sum()
        
        total_hits_in_loader += total_hits_in_batch
        total_correct_predictions += correct_predictions
    
    return (total_correct_predictions / total_seq_len_in_dataloader).item(), (total_hits_in_loader / total_seq_len_in_dataloader).item()

def train_epoch(model: Phi3ForCausalLM, optimizer, loss_fn, train_dataloader: DataLoader, device):
    model.train()
    losses = 0

    for src in tqdm(train_dataloader, leave=False):
        src = src.to(device)

        src_mask, src_padding_mask = create_masks(src, pad_idx=train_dataloader.dataset.pad_id, device=device)

        logits = model.forward(input_ids=src, attention_mask=src_padding_mask).logits[:, :-1, :]

        src_out = src[:, 1:]

        optimizer.zero_grad()
        
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), src_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)


def evaluate(model: Phi3ForCausalLM, loss_fn, val_dataloader, device, k=3):
    model.eval()
    losses = 0

    for src in tqdm(val_dataloader, leave=False):
        src = src.to(device)
        src_mask, src_padding_mask = create_masks(src, pad_idx=val_dataloader.dataset.pad_id, device=device)

        logits = model.forward(input_ids=src, attention_mask=src_padding_mask).logits[:, :-1, :]

        src_out = src[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), src_out.reshape(-1))
        losses += loss.item()

    acc, hitrate = calc_accuracy_and_hitrate_at_k(model, val_dataloader, device, k)

    return losses / len(val_dataloader), acc, hitrate

In [6]:
import torch.nn as nn
from torch.utils.data import DataLoader
from timeit import default_timer as timer
from transformers import Phi3Config, Phi3ForCausalLM

NUM_EPOCHS = 40
BATCH_SIZE = 64




config = Phi3Config(
    vocab_size=train_dataset.vocab_size,
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=4,
    original_max_position_embeddings=512,
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    attention_dropout=0.1,
    max_position_embeddings=512,
    pad_token_id=train_dataset.pad_id,
    bos_token_id=train_dataset.bos_id,
    eos_token_id=train_dataset.eos_id,
    use_cache=False,
)

model = Phi3ForCausalLM(config)

torch.manual_seed(42)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

transformer = model.to(device)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)


loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train_dataset.pad_id, label_smoothing=0.07)

optimizer = torch.optim.AdamW(transformer.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.01)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=1, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, pin_memory=True)
k = 3

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer, loss_fn, train_loader, device)
    end_time = timer()
    val_loss, val_acc, val_hitrate = evaluate(transformer, loss_fn, val_loader, device, k=k)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Val ACC: {val_acc:.3f}, Val hitrate@{k}: {val_hitrate:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

In [7]:
transformer.sa

In [8]:
calc_accuracy_and_hitrate_at_k(transformer, test_loader, device, 5)

  0%|          | 0/229 [00:00<?, ?it/s]

(0.7374354004859924, 0.9549716114997864)

In [18]:
sum(p.numel() for p in transformer.parameters())

4849895

In [13]:
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
total_seq_len_in_dataloader = 0
total_correct_predictions = 0
for batch in train_loader:
    src_mask, src_pad_mask = (create_masks(batch, pad_idx=0, device=device))
    src_mask = src_mask.to(device)
    src_pad_mask = src_pad_mask.to(device)
    batch = batch.to(device)
    is_bos = (batch == train_loader.dataset.bos_id)

    labels = batch[:, 1:]
    labels_pad_mask = (labels == 0)
    labels_without_pad = labels[~labels_pad_mask]
    logits_without_last = transformer(batch, src_mask, src_pad_mask)[:, :-1, :]
    predictions = logits_without_last.argmax(dim=-1)

    predictions_3 = torch.argsort(logits_without_last, dim=-1, descending=True)[:, :, :3][~labels_pad_mask]

    print(predictions_3 == labels_without_pad.unsqueeze(1))

    # predictions_without_pad = predictions[~labels_pad_mask]

    # total_seq_len_in_batch = labels_without_pad.shape[0]
    # total_predicions_len = predictions_without_pad.shape[0]

    # assert total_seq_len_in_batch == total_predicions_len

    # total_seq_len_in_dataloader += total_seq_len_in_batch

    # correct_predictions = (labels_without_pad == predictions_without_pad).float().sum()

    # total_correct_predictions += correct_predictions

    # print(total_correct_predictions)
    # print(total_seq_len_in_batch)

    
    break
    

tensor([[ True, False, False],
        [ True, False, False],
        [False, False, False],
        [False, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False,  True, False],
        [ True, False, False],
        [ True, False, False],
        [False,  True, False],
        [False,  True, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [ True, False, False],
        [False,  True, False],
        [ True, False, False],
        [ True, False, False],
        [False,  True, False]], device='cuda:0')


In [18]:
def compute_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Compute the accuracy of predictions for an RNN language model.
    
    :param logits: Logits output by the model of shape [batch_size, sequence_length, vocab_size]
    :param labels: Ground truth labels of shape [batch_size, sequence_length]
    :return: Accuracy as a float
    """
    # Find the argmax of the logits along the last dimension to get the most likely token indices
    predictions = logits.argmax(dim=-1)
    
    # Compute the number of correct predictions
    correct_predictions = (predictions == labels).float().sum()
    
    # Calculate the accuracy
    accuracy = correct_predictions / labels.numel()
    
    return accuracy.item()

def calc_accuracy(model, loader) -> float:
    with torch.no_grad():
        model.eval()
        for indices, lengths in tqdm(loader, desc=tqdm_desc):
            optimizer.zero_grad()
            indices = indices[:, :lengths.max()].to(device)
            logits = model(indices[:, :-1], lengths - 1)
            loss = criterion(logits.transpose(1, 2), indices[:, 1:])
            loss.backward()
            optimizer.step()
        
            train_acc += compute_accuracy(logits, indices[:, 1:]) * indices.shape[0]
        
            train_loss += loss.item() * indices.shape[0]
        
        train_loss /= len(loader.dataset)
        train_acc = train_acc / len(loader.dataset)
        return train_loss, train_acc