In [None]:
# Define Dataset
class MedDataset(torch.utils.data.Dataset):
    """Tokenized text-classification dataset built from one split of a DatasetDict.

    NOTE(review): a later notebook cell re-defines ``MedDataset``; that later definition
    shadows this one once both cells have run. Keep only one copy.

    Args:
        dataframe: mapping of split name -> columns; must contain a ``'train'`` split and the
            requested ``split``, each with ``'text'`` and ``'label'`` columns.
        tokenizer: Hugging Face-style tokenizer; called once over the whole split.
        max_length: pad/truncate every example to this many tokens.
        split: which split of ``dataframe`` this dataset exposes.

    Items are dicts of per-example tensors from the tokenizer (e.g. ``input_ids``,
    ``attention_mask``) plus ``labels``: a float32 one-hot row.
    """

    def __init__(self, dataframe: "DatasetDict", tokenizer, max_length: int = 512, split: str = 'train'):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        # The encoder is always fit on the *train* split, so every split shares the same
        # column order. Labels unseen in train encode as all-zero rows (handle_unknown='ignore').
        self.enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False).fit(np.stack(self.dataframe['train']['label']).reshape(-1, 1))
        self.labels = self.enc.transform(np.stack(self.dataframe[split]['label']).reshape(-1, 1))
        # Tokenize the whole split once up front; __getitem__ only slices.
        self.encodings = self.tokenizer(self.dataframe[split]['text'], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')

    def __getitem__(self, idx):
        # clone() so callers never alias (and accidentally mutate) the stored encodings.
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        # OneHotEncoder yields float64. Emit float32 so the targets match the model's
        # floating-point logits instead of silently upcasting (or erroring) in the loss.
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        return len(self.labels)

In [None]:
# Define Dataset
class MedDataset(torch.utils.data.Dataset):
    """Tokenized text-classification dataset built from one split of a DatasetDict.

    NOTE(review): this cell duplicates an earlier ``MedDataset`` definition; because it runs
    later, this is the class actually used. Keep only one copy.

    Args:
        dataframe: mapping of split name -> columns; must contain a ``'train'`` split and the
            requested ``split``, each with ``'text'`` and ``'label'`` columns.
        tokenizer: Hugging Face-style tokenizer; called once over the whole split.
        max_length: pad/truncate every example to this many tokens.
        split: which split of ``dataframe`` this dataset exposes.

    Items are dicts of per-example tensors from the tokenizer (e.g. ``input_ids``,
    ``attention_mask``) plus ``labels``: a float32 one-hot row.
    """

    def __init__(self, dataframe: "DatasetDict", tokenizer, max_length: int = 512, split: str = 'train'):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        # The encoder is always fit on the *train* split, so every split shares the same
        # column order. Labels unseen in train encode as all-zero rows (handle_unknown='ignore').
        self.enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False).fit(np.stack(self.dataframe['train']['label']).reshape(-1, 1))
        self.labels = self.enc.transform(np.stack(self.dataframe[split]['label']).reshape(-1, 1))
        # Tokenize the whole split once up front; __getitem__ only slices.
        self.encodings = self.tokenizer(self.dataframe[split]['text'], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')

    def __getitem__(self, idx):
        # clone() so callers never alias (and accidentally mutate) the stored encodings.
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        # OneHotEncoder yields float64. Emit float32 so the targets match the model's
        # floating-point logits instead of silently upcasting (or erroring) in the loss.
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        return len(self.labels)

In [None]:
# Create Datasets
# Build one MedDataset per split (train, validation, test — in that order); each one fits
# its label encoder on the 'train' split so the one-hot columns line up across splits.
train_dataset, val_dataset, test_dataset = (
    MedDataset(dataset, tokenizer, split=split_name)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
# Pytorch Implementation

# # Set only specific layers to be trainable
# for param in model.base_model.parameters():
#     param.requires_grad = False

# # Set only specific layers to be trainable
# for param in model.classifier.parameters():
#     param.requires_grad = True

# Dataloader
# Only the training split is shuffled. Evaluation loaders keep dataset order, so any
# prediction/label comparison done positionally stays aligned.
BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# Loss
_ce_loss = torch.nn.CrossEntropyLoss()


def loss_fn(logits, targets):
    """Cross-entropy training loss.

    A previous version added a weighted F1 score to this loss. F1 is computed from argmax
    predictions, so it contributes no gradient, and adding a higher-is-better score to a
    minimised objective only distorts the reported loss. Only cross-entropy is used here.

    Args:
        logits: raw model scores, presumably (batch, num_labels) — TODO confirm.
        targets: class indices (batch,) or one-hot rows with the same rank as ``logits``.

    Returns:
        Scalar loss tensor.
    """
    if targets.ndim == logits.ndim:
        # One-hot rows (as produced by MedDataset) -> integer class indices.
        targets = targets.argmax(dim=-1)
    return _ce_loss(logits, targets)


# GPU Memory optimization
model.gradient_checkpointing_enable()
# `fp16=True` is the deprecated spelling (removed in newer accelerate releases);
# mixed_precision="fp16" is the supported equivalent.
accelerator = Accelerator(mixed_precision="fp16")
model, optimizer, train_loader, val_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, val_loader, test_loader)

In [None]:
# Training
epochs = 36

for epoch in range(epochs):
    pbar = tqdm.tqdm(train_loader)
    for i, batch in enumerate(pbar):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = loss_fn(outputs.logits, labels)
        loss.backward()
        optimizer.step()

        # Logging Progress
        if i % 10 == 0:
            pbar.set_description(f"Epoch {epoch} training loss: {loss.item()}")
    
    # Evaluate on Validation
    val_CE_loss = []
    val_f1 = []

    pbar = tqdm.tqdm(val_loader)
    pbar.set_description(f"Epoch {epoch} Validation")
    for batch in pbar:
        with torch.no_grad():
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        val_CE_loss.append(torch.nn.CrossEntropyLoss()(outputs.logits, labels).item())
        val_f1.append(torchmetrics.classification.MulticlassF1Score(num_classes=num_labels, average='weighted').to(device)(outputs.logits, labels).item())
    
    print(f"Epoch {epoch} CrossEntropy Val loss: {np.mean(val_CE_loss)}")
    print(f"Epoch {epoch} F1 Val score: {np.mean(val_f1)}")
    
    # # Saving Model    
    # if epoch % 10 == 0:
    #     torch.save(model.state_dict(), paths.MODEL_PATH/f"line-label_medBERT-finetuned_{epoch}.pt")


In [None]:
# Release cached GPU memory. Collect garbage first so tensors that are only kept alive
# by reference cycles are actually freed before the caching allocator is emptied.
import gc

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Free GPU Memory.
# Drop the Python references first: empty_cache() can only hand back blocks whose tensors
# are no longer referenced, so calling it before the deletes released nothing.
import gc

for _name in ("input_ids", "attention_mask", "labels"):
    # pop(..., None) instead of `del` so re-running the cell does not raise NameError.
    globals().pop(_name, None)
del _name
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Display the checkpoint path that the commented-out save block in the training cell
# would build, using whatever `epoch` value was left over from the training loop.
# NOTE(review): scratch cell that depends on leftover loop state; consider deleting.
paths.MODEL_PATH/f"line-label_medBERT-finetuned_{epoch}.pt"

In [None]:
model.eval()

# Evaluate on test set.
# shuffle must be False: the following cell compares predictions positionally against
# labels stored in dataset order, so a shuffled loader pairs predictions with the wrong rows.
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
outputs = []

with torch.no_grad():
    for batch in tqdm.tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs.append(model(input_ids, attention_mask=attention_mask))


In [None]:
# Get predictions: argmax over each row of logits, in test_loader order.
preds = np.concatenate([np.argmax(output.logits.cpu().numpy(), axis=1) for output in outputs])

# Get true labels from the same dataset the loader iterated over. This guarantees the
# same one-hot encoding (encoder fit on 'train') and the same row order as `preds`,
# instead of relying on a `test_labels_enc` variable defined elsewhere.
true = np.argmax(np.asarray(test_dataset.labels), axis=1)
assert len(preds) == len(true), f"{len(preds)} predictions vs {len(true)} labels"

# Calculate accuracy
acc = np.mean(preds == true)

# F1 Score
from sklearn.metrics import f1_score
f1 = f1_score(true, preds, average='weighted')
print(f"Accuracy: {acc}")
print(f"F1 Score: {f1}")

In [None]:
# Freeze every parameter except those of the pooler and the classification head:
# a parameter stays trainable only if its name mentions one of those two modules.
for param_name, parameter in trainer.model.named_parameters():
    parameter.requires_grad = any(part in param_name for part in ("pooler", "classifier"))

In [None]:
# Train with the Hugging Face Trainer. Only the parameters left trainable by the
# previous cell (pooler / classifier) have requires_grad=True.
trainer.train()

In [None]:
# Save the trained model under MODEL_PATH/medbert-diag-label, using the same pathlib `/`
# joining as the other MODEL_PATH usages in this notebook.
trainer.save_model(str(paths.MODEL_PATH / "medbert-diag-label"))

In [None]:
# Evaluate with the Trainer's configured eval dataset; the returned metrics dict is
# displayed as the cell output.
trainer.evaluate()