In [1]:
from preprocessing import *




In [2]:
from transformers import LEDTokenizer, LEDForConditionalGeneration
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW

In [3]:
tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384", gradient_checkpointing=True, use_cache=False)
file_path = 'new_court_cases.csv'
# Prepare the data, model, and tokenizer before training
preprocessor = preprocess(file_path, tokenizer, model)
ready_model, ready_tokenizer, ready_data = preprocessor.return_model_tokenizer_data()

  return self.fget.__get__(instance, owner)()


In [4]:
class dataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]


In [5]:
class modelling:
    def __init__(self, model, tokenizer, data, epochs = 3):
        self.data = data
        self.model = model
        self.epochs = epochs
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Move the model to the GPU
        self.model.to(self.device)

        self.optimizer = AdamW(self.model.parameters(), lr=5e-5)  # Adjust learning rate as needed
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)  # Ignore padding tokens in loss

        # Instantiate the dataset and DataLoader
        self.dataset = dataset(data=self.data)  # Replace with actual preprocessed data
        self.dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True)  # Batch size can be adjusted based on your memory capacity

    def finetune(self):
        self.model.train()  # Put model in training mode
    
        for self.epoch in range(self.epochs):
            total_loss = 0
            for batch in self.dataloader:
                # Move input data to the GPU
                input_ids = batch["input_ids"].squeeze(1).to(self.device)
                attention_mask = batch["attention_mask"].squeeze(1).to(self.device)
                global_attention_mask = batch["global_attention_mask"].squeeze(1).to(self.device)
                decoder_input_ids = batch["decoder_input_ids"].squeeze(1).to(self.device)
                labels = batch["labels"].squeeze(1).to(self.device)

                print(f"input_ids.shape: {input_ids.shape}")
                print(f"attention_mask.shape: {attention_mask.shape}")
                print(f"global_attention_mask.shape: {global_attention_mask.shape}")
                print(f"labels.shape: {labels.shape}")
                print(f"decoder: {self.model.config.max_decoder_position_embeddings}") 
                print(f"Max input_ids: {input_ids.max()}")
                print(f"Max labels: {labels.max()}")
                print(f"Model vocab size: {self.model.config.vocab_size}")
    
                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    global_attention_mask=global_attention_mask,
                    labels=labels
                )
    
                logits = outputs.logits  # [batch_size, sequence_length, vocab_size]
                
                # Reshape logits and labels for CrossEntropyLoss
                logits = logits.view(-1, logits.size(-1))  # [batch_size * sequence_length, vocab_size]
                labels = labels.view(-1)  # [batch_size * sequence_length]
    
                # Compute the loss
                loss = self.criterion(logits, labels)
                total_loss += loss.item()
    
                # Backward pass and optimization
                self.optimizer.zero_grad()  # Clear previous gradients
                loss.backward()  # Backpropagation
                self.optimizer.step()  # Update model parameters
    
            print(f"Epoch {self.epoch + 1}/{self.epochs}, Loss: {total_loss / len(self.dataloader)}")
        

In [6]:
modeller = modelling(ready_model, ready_tokenizer, ready_data)



In [7]:
modeller.finetune()

input_ids.shape: torch.Size([2, 16384])
attention_mask.shape: torch.Size([2, 16384])
global_attention_mask.shape: torch.Size([2, 16384])
labels.shape: torch.Size([2, 16384])
decoder:  1024
Max input_ids: 49072
Max labels: 49072
Model vocab size: 50265


KeyboardInterrupt: 