In [7]:
from datetime import datetime
import torch
import jsonlines
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, Subset
# from torch.utils.tensorboard import SummaryWriter

# Input data and tokenisation settings.
DATASET_PATH = "kaggle_preprocessed.jsonl"
MAX_LENGTH = 512   # max_length handed to the tokenizer (pad/truncate target)
N_FEATURES = 28    # presumably the width of each record's "features" vector — confirm against the data

# CEFR level names; a label's position in this list is its class index.
LEVELS = "A1 A2 B1 B2 C1 C2".split()

class SimpleWikiDataset(Dataset):
    """Map-style dataset of CEFR-labelled texts loaded from a JSON Lines file.

    Each record is expected to provide a ``"text"`` string, a ``"features"``
    list of auxiliary numeric features and a ``"label"`` that is one of
    ``LEVELS``.
    """

    def __init__(self, tokenizer, max_length, path=None):
        """Load every record from ``path`` into memory.

        Args:
            tokenizer: Hugging Face tokenizer used in ``__getitem__``.
            max_length: pad/truncate length for the tokenized text.
            path: JSON Lines file to read; defaults to ``DATASET_PATH``.
        """
        self.max_length = max_length
        self.tokenizer = tokenizer
        if path is None:
            path = DATASET_PATH
        # Read eagerly and close the file handle immediately instead of
        # leaving the reader open for the lifetime of the dataset.
        with jsonlines.open(path) as reader:
            self.texts = list(reader.iter())

    def __len__(self):
        """Return the number of records loaded."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return ``(inputs, cefr_label)`` for record ``index``.

        ``inputs`` holds 1-D ``input_ids`` and ``attention_mask`` tensors of
        length ``max_length`` and a float32 ``features`` tensor.
        ``cefr_label`` is the label's index in ``LEVELS``; ``LEVELS.index``
        raises ValueError for an unknown label.
        """
        record = self.texts[index]
        # Use the tokenizer this dataset was constructed with (previously the
        # notebook-global ``tokenizer`` was used by accident).
        tokenized = self.tokenizer(record["text"],
                                   return_tensors='pt',
                                   padding='max_length', max_length=self.max_length,
                                   truncation=True)
        inputs = {}
        # Drop only the leading batch axis added by return_tensors='pt';
        # a bare squeeze() would also collapse a length-1 sequence axis.
        inputs["input_ids"] = tokenized["input_ids"].squeeze(0)
        inputs["attention_mask"] = tokenized["attention_mask"].squeeze(0)
        # Force float32 so the features always match the float nn.Linear
        # weights, even if the JSON stores them as integers.
        inputs["features"] = torch.tensor(record["features"], dtype=torch.float32)
        cefr_label = torch.tensor(LEVELS.index(record["label"]))

        return inputs, cefr_label

class CEFRClassifier(nn.Module):
    """Frozen DistilBERT encoder plus a small trainable head.

    The first-token hidden state goes through two dense layers
    (hidden -> hidden -> 128). The auxiliary feature vector goes through one
    dense layer (n_features -> n_features). Both are concatenated and
    projected to ``num_cefr_levels`` logits.
    """

    def __init__(self, num_cefr_levels, n_features=None,
                 pretrained_name="distilbert/distilbert-base-uncased"):
        """Build the model.

        Args:
            num_cefr_levels: number of output classes.
            n_features: width of the auxiliary feature vector; defaults to
                ``N_FEATURES``.
            pretrained_name: Hugging Face checkpoint for the encoder.
        """
        super().__init__()
        if n_features is None:
            n_features = N_FEATURES

        self.bert = AutoModel.from_pretrained(pretrained_name)

        # Freeze distilBERT params: only the head below is trained.
        for param in self.bert.parameters():
            param.requires_grad = False

        # Read the encoder width from its config instead of hard-coding 768.
        hidden_size = self.bert.config.hidden_size

        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 128)
        self.fc3 = nn.Linear(n_features, n_features)
        self.output = nn.Linear(128 + n_features, num_cefr_levels)

        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask, aux_features):
        """Return class logits.

        Assumes ``input_ids``/``attention_mask`` are (batch, seq_len) and
        ``aux_features`` is a float (batch, n_features) tensor — confirm
        against the dataset.
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_output.last_hidden_state
        # Hidden state of the first ([CLS]) token as the sequence summary.
        pooled_output = sequence_output[:, 0]

        x = self.fc1(pooled_output)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        y = self.fc3(aux_features)
        y = self.relu(y)
        y = self.dropout(y)

        return self.output(torch.cat((x, y), -1))

# Seed the global torch RNG so head initialisation, loader shuffling and
# dropout are reproducible across runs.
torch.manual_seed(42)

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-uncased')
dataset = SimpleWikiDataset(tokenizer, MAX_LENGTH)

model = CEFRClassifier(len(LEVELS))
model.to(device)

# Dedicated fixed-seed generator: the 70/20/10 split (and therefore the
# held-out test_set used later) is the same on every run.
generator = torch.Generator().manual_seed(42)
train_set, validation_set, test_set = torch.utils.data.random_split(dataset, [0.7, 0.2, 0.1], generator=generator)

bs = 16

training_loader = DataLoader(train_set, batch_size=bs, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=bs, shuffle=False)

# Training
loss_fn = nn.CrossEntropyLoss()
# The encoder is frozen; hand Adam only the parameters that actually train.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.0005)

def train_step(epoch_idx):
    """Train the notebook-level ``model`` for one pass over ``training_loader``.

    Uses the globals ``model``, ``optimizer``, ``loss_fn``,
    ``training_loader`` and ``device``. Every 10 batches it prints the mean
    loss of that 10-batch window.

    Args:
        epoch_idx: zero-based epoch index. Kept for interface compatibility;
            it is currently unused.

    Returns:
        The mean loss of the last complete 10-batch window. If the epoch had
        fewer than 10 batches, the mean over the batches seen. 0.0 if the
        loader was empty.
    """
    log_every = 10
    running_loss = 0.0
    window_batches = 0
    last_loss = 0.0
    reported = False

    for i, (train_inputs, train_labels) in enumerate(training_loader):
        input_ids = train_inputs["input_ids"].to(device)
        attention_mask = train_inputs["attention_mask"].to(device)
        aux_features = train_inputs["features"].to(device)
        train_labels = train_labels.to(device)

        optimizer.zero_grad()

        logits = model(input_ids, attention_mask, aux_features)
        loss = loss_fn(logits, train_labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        window_batches += 1

        if window_batches == log_every:
            last_loss = running_loss / log_every
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.0
            window_batches = 0
            reported = True

    # Short epochs never fill a window; report what was seen instead of 0.0.
    if not reported and window_batches:
        last_loss = running_loss / window_batches

    return last_loss

import os

EPOCHS = 30
best_vloss = float("inf")
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# torch.save does not create missing directories.
os.makedirs('runs', exist_ok=True)

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch + 1))

    model.train(True)
    avg_loss = train_step(epoch)

    running_vloss = 0.0
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for validation_inputs, validation_labels in validation_loader:
            input_ids = validation_inputs["input_ids"].to(device)
            attention_mask = validation_inputs["attention_mask"].to(device)
            aux_features = validation_inputs["features"].to(device)
            validation_labels = validation_labels.to(device)

            validation_logits = model(input_ids, attention_mask, aux_features)
            validation_loss = loss_fn(validation_logits, validation_labels)

            # .item() keeps the accumulator a plain Python float rather than
            # a device tensor (which would also make best_vloss a tensor).
            running_vloss += validation_loss.item()

    avg_vloss = running_vloss / len(validation_loader)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'runs/model_{}_{}'.format(timestamp, epoch)
        torch.save(model.state_dict(), model_path)

EPOCH 1:
  batch 10 loss: 1.7713922142982483
  batch 20 loss: 1.771324133872986
  batch 30 loss: 1.6771887421607972
  batch 40 loss: 1.5675775766372682
  batch 50 loss: 1.4587822437286377
  batch 60 loss: 1.3937555432319642
LOSS train 1.3937555432319642 valid 1.2478435039520264
EPOCH 2:
  batch 10 loss: 1.3313370108604432
  batch 20 loss: 1.191392707824707
  batch 30 loss: 1.172075092792511
  batch 40 loss: 1.186892330646515
  batch 50 loss: 1.102652096748352
  batch 60 loss: 1.1775652527809144
LOSS train 1.1775652527809144 valid 1.0461429357528687
EPOCH 3:
  batch 10 loss: 1.096544075012207
  batch 20 loss: 1.0057056665420532
  batch 30 loss: 0.9921921789646149
  batch 40 loss: 1.12746462225914
  batch 50 loss: 1.1043359398841859
  batch 60 loss: 1.0939607203006745
LOSS train 1.0939607203006745 valid 1.0117326974868774
EPOCH 4:
  batch 10 loss: 1.0843705415725708
  batch 20 loss: 0.9726447820663452
  batch 30 loss: 0.9289063811302185
  batch 40 loss: 1.0519471645355225
  batch 50 loss

KeyboardInterrupt: 

In [10]:
# Prefer the GPU, but fall back to CPU so evaluation runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): epoch-9 checkpoint; the training cell above was interrupted
# during epoch 4, so this file presumably comes from a different run.
# Point it at a checkpoint produced by the current training run.
CHECKPOINT_PATH = "runs/model_20240712_092923_9"

eval_model = CEFRClassifier(len(LEVELS))
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
eval_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

bs = 16
# Order does not affect loss/accuracy; keep evaluation deterministic.
test_loader = DataLoader(test_set, batch_size=bs, shuffle=False)

def compute_accuracy(big_idx, targets):
    """Count positions where the predicted class index equals the target.

    Args:
        big_idx: tensor of predicted class indices.
        targets: tensor of true class indices, comparable element-wise
            with ``big_idx``.

    Returns:
        The number of matches as a Python number.
    """
    matches = big_idx.eq(targets)
    return matches.sum().item()

def validate(model, testing_loader):
    """Evaluate ``model`` on ``testing_loader`` and return accuracy in percent.

    Uses the notebook-level ``loss_fn`` and ``device``. It prints a running
    mean loss and accuracy after batch 0 and every 100th batch. At the end it
    prints the mean per-batch loss and the overall accuracy.

    Args:
        model: classifier called as ``model(input_ids, attention_mask,
            features)``; assumed to return (batch, num_classes) logits.
        testing_loader: iterable of ``(inputs, labels)`` batches where
            ``inputs`` has ``input_ids``, ``attention_mask`` and
            ``features`` entries.

    Returns:
        Percentage of examples whose argmax prediction equals the label.

    Raises:
        ValueError: if ``testing_loader`` yields no batches.
    """
    model.eval()
    model.to(device)

    n_correct = 0
    total_loss = 0.0
    nb_tr_examples = 0
    nb_tr_steps = 0

    with torch.no_grad():
        for i, (test_inputs, test_labels) in enumerate(testing_loader):
            ids = test_inputs['input_ids'].to(device, dtype=torch.long)
            mask = test_inputs['attention_mask'].to(device, dtype=torch.long)
            aux_features = test_inputs["features"].to(device)
            targets = test_labels.to(device, dtype=torch.long)
            # No .squeeze(): a final batch of size 1 would collapse to shape
            # (num_classes,), breaking both the loss and the dim=1 argmax.
            outputs = model(ids, mask, aux_features)

            loss = loss_fn(outputs, targets)
            total_loss += loss.item()

            _, big_idx = torch.max(outputs, dim=1)
            n_correct += compute_accuracy(big_idx, targets)

            nb_tr_examples += targets.size(0)
            nb_tr_steps += 1

            # Running averages so far (printed at batch 0, 100, 200, ...).
            if i % 100 == 0:
                loss_step = total_loss / (i + 1)
                accu_step = (n_correct * 100) / nb_tr_examples
                print(f"Validation Loss per 100 steps: {loss_step}")
                print(f"Validation Accuracy per 100 steps: {accu_step}")

    if nb_tr_steps == 0:
        raise ValueError("testing_loader yielded no batches")

    epoch_loss = total_loss / nb_tr_steps
    epoch_accu = (n_correct * 100) / nb_tr_examples
    print(f"Validation Loss Epoch: {epoch_loss}")
    print(f"Validation Accuracy Epoch: {epoch_accu}")

    return epoch_accu

# Score the loaded checkpoint on the held-out test split.
validate(model=eval_model, testing_loader=test_loader)

Validation Loss per 100 steps: 0.8031296133995056
Validation Accuracy per 100 steps: 68.75
Validation Loss Epoch: 0.838558030128479
Validation Accuracy Epoch: 66.44295302013423


66.44295302013423