In [1]:
from numba import cuda 
device = cuda.get_current_device()
device.reset()

In [2]:
from transformers import AutoModel, AutoTokenizer, AutoModel
import torch
import utils
from torch.utils.data import DataLoader
from datasets import load_dataset
import transformers
from torch.utils.data import Sampler
from collections import defaultdict
from datasets import DatasetDict

In [3]:
TRAIN_SET = 'train_law.csv'
VALIDATION_SET = 'validation_law.csv'

In [4]:
tokenizer = AutoTokenizer.from_pretrained('readerbench/jurBERT-base')

In [5]:
rows_train = utils.load_csv('train_law.csv')
rows_train = [row for row in rows_train if len(row[-1]) == 1]

rows_validation = utils.load_csv('validation_law.csv')
rows_validation = [row for row in rows_validation if len(row[-1]) == 1]

In [6]:
import re

def remove_number_from_string(input_string):
    result = re.sub(r'^\d+\s*', '', input_string)
    return result

In [7]:
train_set = load_dataset('csv', data_files=TRAIN_SET)['train']
validation_set = load_dataset('csv', data_files=VALIDATION_SET)['train']

In [8]:
# Removed 'context' column

train_set = train_set.remove_columns(column_names=['choice_index','context','bert_input','prompt'])
validation_set = validation_set.remove_columns(column_names=['choice_index','context','bert_input','prompt'])

In [9]:
train_set = train_set.rename_column(original_column_name='question_index', new_column_name='index')
validation_set = validation_set.rename_column(original_column_name='question_index', new_column_name='index')

In [10]:
validation_set

Dataset({
    features: ['index', 'question', 'choice', 'label'],
    num_rows: 3633
})

In [11]:
# TODO: Remove numbers, lowercase, tokenize

def tokenize(samples):
    index = samples['index']
    question = samples['question']
    choice = samples['choice']
    label = samples['label']
    
    letter = choice[0]
    q = remove_number_from_string(question)[1:].strip().lower()
    c = choice[2:].strip().lower()
    
    tokenized_samples = {}
    
    tokenized_question = tokenizer(q, padding=False, truncation=False, add_special_tokens=False)
    tokenized_choice = tokenizer(c, padding=False, truncation=False, add_special_tokens=False)
    
    # tokenized_samples['question_ids'] = tokenized_question['input_ids']
    for k,v in tokenized_question.items():
        tokenized_samples['question_' + k] = v
    
    for k,v in tokenized_choice.items():
        tokenized_samples['choice_' + k] = v
    
    #tokenized_samples['choice_ids'] = tokenized_choice['input_ids']
    
    tokenized_samples['index'] = index
    
    if letter == label:
        tokenized_samples['label'] = 1
    else:
        tokenized_samples['label'] = -1
    
    return tokenized_samples

In [12]:
train_encoded = train_set.map(tokenize, batched=False, remove_columns=['question', 'choice'])
validation_encoded = validation_set.map(tokenize, batched=False, remove_columns=['question', 'choice'])

In [13]:
class GroupedByIndexSampler(Sampler):
    def __init__(self, data_source, shuffle=False):
        self.data_source = data_source
        self.shuffle = shuffle

        # Group indices by the "index" column value
        self.index_groups = defaultdict(list)
        for idx, item in enumerate(data_source):
            self.index_groups[item['index']].append(idx)

        # Convert the defaultdict to a list of index groups
        self.groups = list(self.index_groups.values())

    def __iter__(self):
        # Shuffle the groups if you want (optional)
        if not self.shuffle:
            torch.manual_seed(0)  # For reproducibility
        indices = torch.randperm(len(self.groups)).tolist()
        for i in indices:
            yield self.groups[i]

    def __len__(self):
        return len(self.groups)

In [14]:
def collate_fn(samples):
    max_length_questions = max([len(sample['question_input_ids']) for sample in samples])
    max_length_choice = max([len(sample['choice_input_ids']) for sample in samples])
    
    for i, sample in enumerate(samples):
        question_input_ids = sample['question_input_ids']
        question_token_type_ids = sample['question_token_type_ids']
        question_attention_mask = sample['question_attention_mask']
        
        choice_input_ids = sample['choice_input_ids']
        choice_token_type_ids = sample['choice_token_type_ids']
        choice_attention_mask = sample['choice_attention_mask']
        
        question_input_ids = question_input_ids + [0] * (max_length_questions - len(question_input_ids))
        question_token_type_ids = question_token_type_ids + [0] * (max_length_questions - len(question_token_type_ids))
        question_attention_mask = question_attention_mask + [0] * (max_length_questions - len(question_attention_mask))
        
        choice_input_ids = choice_input_ids + [0] * (max_length_choice - len(choice_input_ids))
        choice_token_type_ids = choice_token_type_ids + [0] * (max_length_choice - len(choice_token_type_ids))
        choice_attention_mask = choice_attention_mask + [0] * (max_length_choice - len(choice_attention_mask))
        
        samples[i]['question_input_ids'] = question_input_ids
        samples[i]['question_token_type_ids'] = question_token_type_ids
        samples[i]['question_attention_mask'] = question_attention_mask
        
        samples[i]['choice_input_ids'] = choice_input_ids
        samples[i]['choice_token_type_ids'] = choice_token_type_ids
        samples[i]['choice_attention_mask'] = choice_attention_mask
    
    collated_samples = {
        'question_input_ids': [],
        'question_token_type_ids': [],
        'question_attention_mask': [],
        'choice_input_ids': [],
        'choice_token_type_ids': [],
        'choice_attention_mask': [],
        'label': [],
        'index': []
    }

    for key, l in collated_samples.items():
        for sample in samples:
            l.append(sample[key])
        collated_samples[key] = torch.tensor(collated_samples[key])
    
    return collated_samples

In [15]:
sampler = GroupedByIndexSampler(validation_encoded)
validation_dataloader = DataLoader(validation_encoded, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=False)
train_dataloader = DataLoader(train_encoded, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=False)

In [16]:
def col_batch(batch):
    cbatch = {
        'question' : {
            'input_ids': [],
            'attention_mask': [],
            'token_type_ids': []
        },
        'choice' : {
            'input_ids': [],
            'attention_mask': [],
            'token_type_ids': []
        },
        'label': [],
        'index': []
    }
    
    l = len(batch['label'])
    
    cbatch['question']['input_ids'] = batch['question_input_ids']
    cbatch['question']['attention_mask'] = batch['question_attention_mask']
    cbatch['question']['token_type_ids'] = batch['question_token_type_ids']
    
    cbatch['choice']['input_ids'] = batch['choice_input_ids']
    cbatch['choice']['attention_mask'] = batch['choice_attention_mask']
    cbatch['choice']['token_type_ids'] = batch['choice_token_type_ids']
    
    cbatch['question']['input_ids'] = torch.cat((torch.tensor([[tokenizer.cls_token_id]] * l), cbatch['question']['input_ids']), dim=-1).int()
    cbatch['question']['attention_mask'] = torch.cat((torch.tensor([[1]] * l), cbatch['question']['attention_mask']), dim=-1).int()
    cbatch['question']['token_type_ids'] = torch.cat((torch.tensor([[0]] * l), cbatch['question']['token_type_ids']), dim=-1).int()
    
    cbatch['choice']['input_ids'] = torch.cat((torch.tensor([[tokenizer.cls_token_id]] * l), cbatch['choice']['input_ids']), dim=-1).int()
    cbatch['choice']['attention_mask'] = torch.cat((torch.tensor([[1]] * l), cbatch['choice']['attention_mask']), dim=-1).int()
    cbatch['choice']['token_type_ids'] = torch.cat((torch.tensor([[0]] * l), cbatch['choice']['token_type_ids']), dim=-1).int()
    
    cbatch['label'] = batch['label'].int()
    cbatch['index'] = batch['index'].int()
    
    return cbatch

In [17]:
import gc
def clean_batch(batch):
    if type(batch) != dict:
        del batch
        #gc.collect()
        #torch.cuda.empty_cache()
        return
    
    for _,v in batch.items():
        clean_batch(v)

In [18]:
# from cross_bert import CrossBERT
# from col_bert import ColBERT
import torch.nn as nn

class ColBERT(nn.Module):
    def __init__(
        self,
        bert1,
        bert2,
        similarity=True,
        *args, 
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.bert1 = bert1
        self.bert2 = bert2
        
        self.cos = nn.CosineSimilarity()
        self.loss_fn = nn.CosineEmbeddingLoss()

    def forward(self, question, choice, label):
        out1 = self.bert1(**question)
        out1 = out1.last_hidden_state
        out1 = out1[:, 0]
        
        # if self.siamese:
        #     out2 = self.bert1(**choice)
        #     out2 = out2.last_hidden_state
        #     out2 = out2[:, 0]
        # else:
        out2 = self.bert2(**choice)
        out2 = out2.last_hidden_state
        out2 = out2[:, 0]
        
        # with torch.no_grad():
        #     out = self.cos(out1, out2)
        
        loss = self.loss_fn(out1, out2, label)
        
        # out1 = out1.detach().cpu()
        # out2 = out2.detach().cpu()
        # del out1
        # del out2
        # for _ in range(1):
        #     gc.collect()
        #     torch.cuda.empty_cache()
        
        return loss

bert1 = AutoModel.from_pretrained('readerbench/jurBERT-base').cuda()
bert2 = AutoModel.from_pretrained('readerbench/jurBERT-base').cuda()
model = ColBERT(bert1, bert2)

In [19]:
import torch.nn as nn

EPOCHS = 30
lr = 1e-4
#optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
loss_fn = nn.CosineEmbeddingLoss()

In [20]:
# for batch in dataloader:
#     cbatch = col_batch(batch)
    
#     label = cbatch['label']
#     question = cbatch['question']
#     choice = cbatch['choice']
#     print(model(question, choice))
#     break
import gc
from tqdm import tqdm
epoch_train_loss = []
epoch_eval_loss = []
step_train_loss = []
step_eval_loss = []

for epoch in range(EPOCHS):
    
    print('EPOCH', epoch)
    print('TRAIN')
    train_loss = 0
    model.train()
    for batch in tqdm(train_dataloader, total=len(train_dataloader)):
        cbatch = col_batch(batch)
        
        for k, v in cbatch['question'].items():
            cbatch['question'][k] = v.cuda()
        
        for k, v in cbatch['choice'].items():
            cbatch['choice'][k] = v.cuda()
        
        cbatch['label'] = cbatch['label'].cuda()
        
        # label = cbatch['label']
        # question = cbatch['question']
        # choice = cbatch['choice']
        
        out1 = bert1(**cbatch['question']).last_hidden_state[:, 0]
        #out2 = bert2(**cbatch['choice']).last_hidden_state[:, 0]
        
        del out1
        #del out2
        
        bert1.zero_grad()
        #bert2.zero_grad()
        #loss = loss_fn(out1, out2, cbatch['label'])
        
        #loss = model(cbatch['question'], cbatch['choice'], cbatch['label'])
        # loss = loss_fn(out1, out2, cbatch['label'])
        
        # out1 = model1(**cbatch['choice']).last_hidden_state[:, 0]
        # out2 = model2(**cbatch['question']).last_hidden_state[:, 0]
        
        del cbatch['question']['input_ids']
        del cbatch['question']['token_type_ids']
        del cbatch['question']['attention_mask']
        
        del cbatch['choice']['input_ids']
        del cbatch['choice']['token_type_ids']
        del cbatch['choice']['attention_mask']
        
        del cbatch['label']
        del cbatch['index']
        
        del batch
        del cbatch
        
        #loss = loss.detach().cpu()
        # del out1
        # del out2
        #del loss

        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()
        
        # loss = loss.detach().cpu()
        # out1 = out1.detach().cpu()
        # out2 = out2.detach().cpu()
        
        # for k, v in cbatch['question'].items():
        #     cbatch['question'][k] = v.cpu()
        
        # for k, v in cbatch['choice'].items():
        #     cbatch['choice'][k] = v.cpu()
        
        # cbatch['label'] = cbatch['label'].cpu()

       # step_train_loss.append(loss.cpu().item())
        #train_loss += step_train_loss[-1]
        # clean_batch(cbatch)
        # clean_batch(batch)
        
        # del loss
        # del out1
        # del out2
        # del label
        # del question
        # del choice
        for _ in range(1):
            gc.collect()
            torch.cuda.empty_cache()
    
    #train_loss = train_loss / len(train_dataloader)
    #epoch_train_loss.append(train_loss)
    
    print('EVAL')
    eval_loss = 0
    model.eval()
    for batch in tqdm(validation_dataloader, total=len(validation_dataloader)):
        cbatch = col_batch(batch)
        
        for k, v in cbatch['question'].items():
            cbatch['question'][k] = v.cuda()
        
        for k, v in cbatch['choice'].items():
            cbatch['choice'][k] = v.cuda()
        
        cbatch['label'] = cbatch['label'].cuda()
        
        with torch.no_grad():
            label = cbatch['label']
            question = cbatch['question']
            choice = cbatch['choice']
            out1, out2 = model(question, choice, loss)
            loss = loss_fn(out1, out2, label)
        
        #step_eval_loss.append(loss.cpu().item())
        #eval_loss += step_eval_loss[-1]
        # clean_batch(cbatch)
        # clean_batch(batch)
        
        #for k, v in cbatch['question'].items():
        del cbatch['question']['input_ids']
        del cbatch['question']['token_type_ids']
        del cbatch['question']['attention_mask']
        
        del cbatch['choice']['input_ids']
        del cbatch['choice']['token_type_ids']
        del cbatch['choice']['attention_mask']
        
        #for k, v in cbatch['choice'].items():
        
        del cbatch['label']
        del cbatch['index']
        
        del batch
        del cbatch
        del loss
        del out1
        del out2
        del label
        del question
        del choice
        for _ in range(1):
            gc.collect()
            torch.cuda.empty_cache()
    
    #eval_loss = eval_loss / len(validation_dataloader)
    #epoch_eval_loss.append(eval_loss)

EPOCH 0
TRAIN


 14%|█▎        | 40/295 [00:38<04:05,  1.04it/s]


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
