In [None]:
%pip install transformers torch pandas
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2Model
import csv
from zipfile import ZipFile

# Run on the first CUDA device when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def unzip_data(zip_path="pi_deepset.zip", extract_to='datasets'):
    """Extract every member of a zip archive into a directory.

    Args:
        zip_path: Path of the ``.zip`` archive to open.
        extract_to: Destination directory; members are written beneath it
            using their archived relative paths.
    """
    archive = ZipFile(zip_path, 'r')
    # NOTE(review): only extract archives you trust; extractall writes every
    # member listed in the archive.
    with archive:
        archive.extractall(path=extract_to)

In [None]:
def load_data(directory):
    """Read the train / validation / test parquet splits from ``directory``.

    Args:
        directory: Folder expected to contain ``train.parquet``,
            ``validation.parquet`` and ``test.parquet``.

    Returns:
        A ``(train, validation, test)`` tuple of DataFrames, in that order.
    """
    split_files = ('train.parquet', 'validation.parquet', 'test.parquet')
    frames = [pd.read_parquet(os.path.join(directory, fname)) for fname in split_files]
    return tuple(frames)

In [None]:
class PromptDataset(Dataset):
    """Map-style dataset yielding ``(first, second, label)`` triples.

    Args:
        pairs: Indexable sequence whose items are indexable pairs; element 0
            and element 1 of each item are returned as the two inputs.
        labels: Indexable sequence of per-pair labels; its length defines the
            dataset length.
    """

    def __init__(self, pairs, labels):
        # Both sequences are kept by reference; nothing is copied or checked.
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        return pair[0], pair[1], self.labels[idx]

In [None]:
class SiameseNetwork(nn.Module):
    """Siamese LSTM scorer: one shared encoder plus a linear head on |h1 - h2|.

    Both inputs pass through the same single-layer LSTM; the element-wise
    absolute difference of their final hidden states is mapped to one raw
    logit per pair. No sigmoid is applied, so train with a logits-based loss
    such as ``nn.BCEWithLogitsLoss``.

    Args:
        embedding_dim: Size of the last (feature) dimension of each input.
        hidden_size: LSTM hidden size. Defaults to 128 (the original value),
            which keeps ``state_dict`` shapes compatible with old checkpoints.
    """

    def __init__(self, embedding_dim, hidden_size=128):
        super(SiameseNetwork, self).__init__()
        self.embedding_dim = embedding_dim
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward_once(self, x):
        """Encode one side of the pair into its final LSTM hidden state.

        Args:
            x: Either ``[batch, embedding_dim]`` (treated as a length-1
                sequence) or ``[batch, seq_len, embedding_dim]``.

        Returns:
            Tensor of shape ``[batch, hidden_size]``.
        """
        # Add the sequence axis only when it is missing. Unsqueezing a tensor
        # that is already [batch, seq, dim] would make it 4-D, which nn.LSTM
        # rejects.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        _, (hidden, _) = self.lstm(x)
        # hidden is [num_layers, batch, hidden_size]; keep the top layer.
        # Indexing (rather than squeeze(0)) stays correct for any layer count.
        return hidden[-1]

    def forward(self, input1, input2):
        """Return ``[batch, 1]`` raw logits comparing ``input1`` and ``input2``."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        distance = torch.abs(output1 - output2)
        return self.fc(distance)

In [None]:
def prepare_data(data, tokenizer, max_length=1024):
    """Tokenize each text to a fixed length and stack the token-id rows.

    Args:
        data: Iterable of strings (in this notebook, a pandas Series of prompts).
        tokenizer: Hugging Face-style tokenizer, called once per text with
            ``padding='max_length'`` and ``truncation=True``.
        max_length: Target sequence length passed to the tokenizer.

    Returns:
        Tensor of stacked token ids, one row per text; assuming the tokenizer
        honours padding/truncation, its shape is ``[len(data), max_length]``.
    """
    rows = []
    for text in data:
        encoded = tokenizer(text, padding='max_length', max_length=max_length, truncation=True, return_tensors='pt')
        # return_tensors='pt' yields a batch of one; drop that leading axis.
        rows.append(encoded['input_ids'].squeeze(0))
    return torch.stack(rows)

def create_pairs(data, tokenizer):
    """Build every ordered pair of distinct rows plus a same-label target.

    For ``n`` rows this produces ``n * (n - 1)`` pairs: both ``(i, j)`` and
    ``(j, i)`` are included and the diagonal ``i == j`` is skipped. Memory
    and epoch time therefore grow quadratically with the dataset size.

    Args:
        data: DataFrame with a ``'user_input'`` text column and a ``'label'``
            column.
        tokenizer: Tokenizer forwarded to ``prepare_data``.

    Returns:
        ``(pairs, pair_labels)``: a list of ``(row_i, row_j)`` tuples taken
        from the tokenized tensor, and a parallel list of ints that is 1 when
        rows i and j carry equal labels and 0 otherwise.
    """
    def _off_diagonal(n):
        # (a, b) index pairs in row-major order, skipping a == b.
        for a in range(n):
            for b in range(n):
                if a != b:
                    yield a, b

    encoded = prepare_data(data['user_input'], tokenizer)
    row_labels = data['label'].values

    pairs = [(encoded[a], encoded[b]) for a, b in _off_diagonal(len(encoded))]
    pair_labels = [1 if row_labels[a] == row_labels[b] else 0 for a, b in _off_diagonal(len(row_labels))]
    return pairs, pair_labels

In [None]:
def save_losses(losses, filename='training_loss.csv'):
    """Write loss rows to a CSV file under an ``Epoch,Loss`` header.

    Args:
        losses: Iterable of rows; each row must itself be an iterable such as
            ``[epoch, loss]`` (``csv.writer`` rejects bare numbers).
        filename: Output path; an existing file is overwritten.
    """
    with open(filename, 'w', newline='') as handle:
        out = csv.writer(handle)
        out.writerow(['Epoch', 'Loss'])
        for row in losses:
            out.writerow(row)

def save_model(model, filename='dansn.pth'):
    """Persist the model's ``state_dict`` (parameters and buffers) to ``filename``.

    Only weights are stored, not the class, so reloading requires building
    the same architecture and calling ``load_state_dict`` on it.
    """
    weights = model.state_dict()
    torch.save(weights, filename)

In [None]:
def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given, the pass trains (backward + step). Otherwise
    it evaluates with gradients disabled. An empty loader yields
    ``float('nan')`` instead of raising.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total, batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for input1, input2, labels in tqdm(loader, desc=desc):
            # NOTE(review): GPT-2 token ids are cast to float and used directly
            # as a 1024-dim feature vector. GPT2Model is imported but unused,
            # so learned embeddings were presumably intended. Confirm.
            input1 = input1.to(device).float()
            input2 = input2.to(device).float()
            labels = labels.to(device).float().view(-1, 1)
            if training:
                optimizer.zero_grad()
            loss = criterion(model(input1, input2), labels)
            if training:
                loss.backward()
                optimizer.step()
            total += loss.item()
            batches += 1
    return total / batches if batches else float('nan')


def main(zip_path='/content/pi_deepset.zip', dataset_filepath='datasets/pi_deepset',
         epochs=10, batch_size=32, lr=0.001, seed=42):
    """Train the Siamese prompt classifier end to end and save its artefacts.

    Steps:
    1. Extract the archive and load the train/validation parquet splits.
    2. Build all ordered pairs of distinct rows with a same-label target.
    3. Train for ``epochs`` epochs with Adam + ``BCEWithLogitsLoss``,
       reporting validation loss after each epoch.
    4. Write ``training_loss.csv`` (one ``[epoch, mean train loss]`` row per
       epoch, 1-based) and ``dansn.pth``.

    Args:
        zip_path: Archive passed to ``unzip_data``. The default is a Colab
            upload location.
        dataset_filepath: Directory holding the extracted parquet splits.
        epochs: Number of passes over the training pairs.
        batch_size: Mini-batch size for both loaders.
        lr: Adam learning rate.
        seed: Seed for torch's RNGs (weight init, DataLoader shuffling).
            ``None`` leaves them unseeded.
    """
    if seed is not None:
        torch.manual_seed(seed)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    if tokenizer.pad_token is None:
        # GPT-2 has no pad token, and padding='max_length' requires one.
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): assumes the archive holds a top-level pi_deepset/ folder,
    # so extracting into unzip_data's default directory yields
    # dataset_filepath. Confirm.
    unzip_data(zip_path)
    train_data, val_data, _ = load_data(dataset_filepath)

    train_pairs, train_labels = create_pairs(train_data, tokenizer)
    val_pairs, val_labels = create_pairs(val_data, tokenizer)
    print(f'{len(train_pairs)} training pairs, {len(val_pairs)} validation pairs')

    train_loader = DataLoader(PromptDataset(train_pairs, train_labels), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(PromptDataset(val_pairs, val_labels), batch_size=batch_size, shuffle=False)
    print(f'{len(train_loader)} training batches per epoch')

    # 1024 must equal prepare_data's default max_length, because each padded
    # token-id row is fed to the model as one input vector.
    model = SiameseNetwork(1024).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()  # the model emits raw logits; the sigmoid lives in the loss

    # Every entry must be a row (list), because csv.writer.writerows rejects
    # bare floats. Per-batch losses are therefore not mixed into this list.
    losses = []
    for epoch in range(1, epochs + 1):
        train_loss = _run_epoch(model, train_loader, criterion, optimizer, desc=f'Epoch {epoch}')
        val_loss = _run_epoch(model, val_loader, criterion, desc=f'Validate {epoch}')
        print(f'Epoch {epoch}, Loss: {train_loss}, Val Loss: {val_loss}')
        losses.append([epoch, train_loss])

    save_losses(losses)
    save_model(model)


# A Jupyter kernel also runs with __name__ == '__main__', so executing this
# cell (or running the file as a script) starts the full training run.
if __name__ == "__main__":
    main()