In [None]:
# prompt: connect to drive to the juro/Qbert directory

# Mount Google Drive inside the Colab VM, then switch the working directory to
# the project folder so that the relative paths used by later cells
# (e.g. the CSV file names and the local `utils` module) resolve there.
from google.colab import drive
drive.mount('/content/drive')
%cd '/content/drive/My Drive/juro/QBert'


Mounted at /content/drive
/content/drive/My Drive/juro/QBert


In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModel

In [None]:
class QBERT(nn.Module):
    """Pretrained BERT encoder followed by a sigmoid head that scores a (question, answer) pair.

    The question and answer token ids are concatenated along the last axis,
    encoded by BERT, and the hidden state at position 0 is mapped to a single
    value in (0, 1) by a linear layer + sigmoid.
    """

    def __init__(
        self,
        variant='readerbench/jurBERT-base',
        no_answers=3,
        *args,
        **kwargs
    ) -> None:
        """Load the tokenizer/encoder for ``variant`` and build the scoring head.

        Args:
            variant: Hugging Face model identifier passed to ``from_pretrained``.
            no_answers: Accepted for backward compatibility; not used by the model.
            *args, **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(variant)
        self.bert = AutoModel.from_pretrained(variant)
        self.bert.to(self.device)

        self.embedding_size = self.bert.config.hidden_size
        print(self.embedding_size)
        self.mlp = nn.Sequential(
            nn.Linear(self.embedding_size, 1),
            nn.Sigmoid()
        )
        self.mlp.to(self.device)

    def _part_mask(self, ids, mask):
        """Return ``mask`` if given, else a mask that hides positions equal to the pad id."""
        if mask is not None:
            return mask.to(torch.long)
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # No pad token known: attend to every position.
            return torch.ones_like(ids, dtype=torch.long)
        return (ids != pad_id).to(torch.long)

    def forward(self, q, a, q_mask=None, a_mask=None):
        """Score each (question, answer) row.

        Args:
            q: Question token ids, one row per pair.
            a: Answer token ids, one row per pair (same number of rows as ``q``).
            q_mask: Optional attention mask for ``q`` (1 = attend, 0 = ignore).
            a_mask: Optional attention mask for ``a``.

        Returns:
            Tensor with one sigmoid score per row (last dimension of size 1).

        When a mask is omitted it is derived from the ids: positions holding the
        tokenizer's pad id are ignored. Without a mask the encoder would attend
        to padding that sits between the question and the answer.
        """
        qa_pair = torch.cat((q, a), dim=-1)
        attention_mask = torch.cat(
            (self._part_mask(q, q_mask), self._part_mask(a, a_mask)), dim=-1
        )
        encoded = self.bert(input_ids=qa_pair, attention_mask=attention_mask)
        # Position 0 is the first token of the question part of the sequence.
        out = encoded.last_hidden_state[:, 0, :]
        out = self.mlp(out)
        return out

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K

In [None]:
from transformers import AutoModel, AutoTokenizer, AutoModel
import torch
import utils
from torch.utils.data import DataLoader
from datasets import load_dataset
import transformers
from torch.utils.data import Sampler
from collections import defaultdict
from datasets import DatasetDict

In [None]:
# CSV files for the two splits, relative to the current working directory.
TRAIN_SET = 'train_law.csv'
VALIDATION_SET = 'validation_law.csv'

tokenizer = AutoTokenizer.from_pretrained('readerbench/jurBERT-base')

# Load the raw rows through the shared constants (so the paths are defined in
# one place) and keep only rows whose last field has length exactly 1
# (presumably a single-letter answer label -- verify against the CSV schema).
rows_train = [row for row in utils.load_csv(TRAIN_SET) if len(row[-1]) == 1]
rows_validation = [row for row in utils.load_csv(VALIDATION_SET) if len(row[-1]) == 1]

In [None]:
import re

def remove_number_from_string(input_string):
    """Drop a leading run of digits, and any whitespace right after it, from ``input_string``.

    Strings that do not start with a digit are returned unchanged.
    """
    leading_number = re.compile(r'^\d+\s*')
    return leading_number.sub('', input_string)

In [None]:
# Columns that are not needed downstream and are dropped from both splits.
_DROPPED_COLUMNS = ['choice_index', 'context', 'bert_input', 'prompt']


def _load_split(csv_path):
    """Load one CSV as a dataset, drop the unused columns and rename 'question_index' to 'index'."""
    # load_dataset('csv', ...) exposes a single file under the 'train' key.
    split = load_dataset('csv', data_files=csv_path)['train']
    split = split.remove_columns(column_names=_DROPPED_COLUMNS)
    return split.rename_column(original_column_name='question_index', new_column_name='index')


train_set = _load_split(TRAIN_SET)
validation_set = _load_split(VALIDATION_SET)


In [None]:

# TODO: Remove numbers, lowercase, tokenize

def tokenize(samples):
    """Turn one dataset row into token features plus a binary label.

    Reads the 'index', 'question', 'choice' and 'label' fields of the row.
    The question has its leading number removed, then its first remaining
    character dropped; the choice has its first two characters dropped
    (presumably the answer letter and a separator -- verify against the data).
    Both texts are stripped, lower-cased and tokenized without padding,
    truncation or special tokens.

    Returns:
        dict with every tokenizer output prefixed by 'question_' / 'choice_',
        the row's 'index', and 'label' set to 1 when the first character of
        the choice equals the row's label, else 0.
    """
    raw_choice = samples['choice']
    answer_letter = raw_choice[0]

    question_text = remove_number_from_string(samples['question'])[1:].strip().lower()
    choice_text = raw_choice[2:].strip().lower()

    encoder_options = dict(padding=False, truncation=False, add_special_tokens=False)
    encoded_question = tokenizer(question_text, **encoder_options)
    encoded_choice = tokenizer(choice_text, **encoder_options)

    features = {'question_' + name: values for name, values in encoded_question.items()}
    features.update({'choice_' + name: values for name, values in encoded_choice.items()})

    features['index'] = samples['index']
    features['label'] = 1 if answer_letter == samples['label'] else 0

    return features

In [None]:
# Tokenize row by row; the raw text columns are dropped once encoded.
_map_options = dict(batched=False, remove_columns=['question', 'choice'])
train_encoded = train_set.map(tokenize, **_map_options)
validation_encoded = validation_set.map(tokenize, **_map_options)

Map:   0%|          | 0/25086 [00:00<?, ? examples/s]

Map:   0%|          | 0/3633 [00:00<?, ? examples/s]

In [None]:

class GroupedByIndexSampler(Sampler):
    """Batch sampler that yields, per batch, every dataset position sharing one 'index' value.

    Each yielded item is a list of integer positions into ``data_source``, so
    the sampler is meant to be passed as ``batch_sampler=`` to a DataLoader
    built on the same ``data_source``.

    Args:
        data_source: Iterable of mapping-like items exposing an 'index' key.
        shuffle: When True the groups are yielded in a fresh random order on
            every iteration; when False they are yielded in order of first
            appearance in ``data_source``.
    """

    def __init__(self, data_source, shuffle=False):
        self.data_source = data_source
        self.shuffle = shuffle

        # Group positions by the value of the "index" column.
        self.index_groups = defaultdict(list)
        for idx, item in enumerate(data_source):
            self.index_groups[item['index']].append(idx)

        # Dicts preserve insertion order, so groups follow first appearance.
        self.groups = list(self.index_groups.values())

    def __iter__(self):
        if self.shuffle:
            order = torch.randperm(len(self.groups)).tolist()
        else:
            # Deterministic order without touching the global RNG: reseeding
            # here would also reset the randomness of dropout and any other
            # torch random op used while iterating.
            order = range(len(self.groups))
        for i in order:
            yield self.groups[i]

    def __len__(self):
        """Number of groups, i.e. number of batches per pass."""
        return len(self.groups)

# %%
def collate_fn(samples):
    """Pad a list of encoded rows to common lengths and stack them into tensors.

    Question fields are right-padded with 0 to the longest question in the
    batch, choice fields to the longest choice. The input ``samples`` are not
    modified; padded copies are built instead.

    Args:
        samples: Non-empty list of dicts with the keys
            'question_input_ids', 'question_token_type_ids',
            'question_attention_mask', 'choice_input_ids',
            'choice_token_type_ids', 'choice_attention_mask', 'label', 'index'.

    Returns:
        dict mapping each of those keys to a ``torch.Tensor`` with one row
        (or one scalar, for 'label' and 'index') per sample.
    """
    collated_samples = {}

    for prefix in ('question', 'choice'):
        # The width of each part is set by its longest input_ids in the batch.
        width = max(len(sample[prefix + '_input_ids']) for sample in samples)
        for suffix in ('input_ids', 'token_type_ids', 'attention_mask'):
            key = prefix + '_' + suffix
            padded_rows = [
                list(sample[key]) + [0] * (width - len(sample[key]))
                for sample in samples
            ]
            collated_samples[key] = torch.tensor(padded_rows)

    for key in ('label', 'index'):
        collated_samples[key] = torch.tensor([sample[key] for sample in samples])

    return collated_samples

# Each dataloader needs a sampler built from ITS OWN dataset: the sampler
# yields row positions, and positions computed on the validation split do not
# describe the grouping of the training split.
sampler = GroupedByIndexSampler(validation_encoded)
train_sampler = GroupedByIndexSampler(train_encoded, shuffle=True)  # new order every epoch

validation_dataloader = DataLoader(validation_encoded, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=False)
train_dataloader = DataLoader(train_encoded, batch_sampler=train_sampler, collate_fn=collate_fn, pin_memory=False)

def col_batch(batch):
    """Regroup a collated batch into nested question/choice dicts and add special tokens.

    One column is prepended to every question field (CLS id / mask 1 / type 0).
    One column is prepended and one appended to every choice field
    (SEP id / mask 1 / type 0 on both sides). All returned tensors are cast
    to 32-bit integers.

    Args:
        batch: dict of tensors with the 'question_*' and 'choice_*' fields plus
            'label' and 'index'.

    Returns:
        dict with keys 'question' and 'choice' (each a dict of 'input_ids',
        'attention_mask', 'token_type_ids'), 'label' and 'index'.
    """
    n_rows = len(batch['label'])

    def _column(value):
        # One extra column holding `value` for every row of the batch.
        return torch.tensor([[value]] * n_rows)

    # Value added around each field: (before question, around choice).
    extras = {
        'input_ids': (tokenizer.cls_token_id, tokenizer.sep_token_id),
        'attention_mask': (1, 1),
        'token_type_ids': (0, 0),
    }

    question = {}
    choice = {}
    for field, (question_value, choice_value) in extras.items():
        question[field] = torch.cat(
            (_column(question_value), batch['question_' + field]), dim=-1
        ).int()
        choice[field] = torch.cat(
            (_column(choice_value), batch['choice_' + field], _column(choice_value)), dim=-1
        ).int()

    return {
        'question': question,
        'choice': choice,
        'label': batch['label'].int(),
        'index': batch['index'].int(),
    }

import gc
def clean_batch(batch):
    """Walk a (possibly nested) plain ``dict`` recursively; always returns None.

    Note: the original leaf step was ``del batch``, which only unbinds the
    local name and releases nothing held by the caller, so this function has
    no observable effect on its argument.
    """
    if type(batch) is dict:
        for value in batch.values():
            clean_batch(value)
    return





In [None]:
# Build the model with its default arguments; the constructor prints the
# encoder's hidden size (the number shown below this cell).
model = QBERT()

768


In [None]:
import gc

import torch.nn as nn
from tqdm import tqdm


EPOCHS = 30
lr = 1e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
loss_fn = nn.MSELoss()

# Per-epoch mean losses, and every per-step loss across all epochs.
epoch_train_loss = []
epoch_eval_loss = []
step_train_loss = []
step_eval_loss = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _batch_loss(batch):
    """Run the model on one collated batch and return the loss tensor."""
    cbatch = col_batch(batch)
    question_ids = cbatch['question']['input_ids'].to(device)
    choice_ids = cbatch['choice']['input_ids'].to(device)
    label = cbatch['label'].float().to(device)
    # squeeze(-1) drops only the score axis, so a one-row batch keeps the
    # same shape as its label instead of collapsing to a 0-d tensor.
    scores = model(question_ids, choice_ids).squeeze(-1)
    return loss_fn(scores, label)


for epoch in range(EPOCHS):

    print('EPOCH', epoch)
    print('TRAIN')
    model.train()
    train_steps = []
    for batch in tqdm(train_dataloader, total=len(train_dataloader)):
        loss = _batch_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_steps.append(loss.item())

    step_train_loss.extend(train_steps)
    train_loss = sum(train_steps) / len(train_dataloader)
    epoch_train_loss.append(train_loss)

    print(f"Loss training epoch {epoch + 1}: {train_loss}")

    with open('train_loss.txt', 'a') as f:
        f.write(f"Loss training epoch {epoch + 1}: {train_loss}\n")
        # Only this epoch's steps, not the list accumulated over all epochs.
        f.write(f"Individual loss {epoch + 1}: {train_steps}\n")

    print('EVAL')
    model.eval()
    eval_steps = []
    for batch in tqdm(validation_dataloader, total=len(validation_dataloader)):
        with torch.no_grad():
            loss = _batch_loss(batch)
        eval_steps.append(loss.item())

    step_eval_loss.extend(eval_steps)
    eval_loss = sum(eval_steps) / len(validation_dataloader)
    epoch_eval_loss.append(eval_loss)
    print(f"Loss testing epoch {epoch + 1}: {eval_loss}")

    with open('test_loss.txt', 'a') as f:
        f.write(f"Loss testing epoch {epoch + 1}: {eval_loss}\n")
        f.write(f"Individual loss {epoch + 1}: {eval_steps}\n")

    # Release cached memory once per epoch; doing this after every step
    # stalls the loop without freeing anything the loop still needs.
    gc.collect()
    torch.cuda.empty_cache()

torch.save(model.state_dict(), "qbert_model.plt")


EPOCH 0
TRAIN


  1%|▏         | 4/295 [00:59<1:24:50, 17.49s/it]