In [1]:
import time
import math
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from xformers.ops import memory_efficient_attention  # Flash Attention
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
#https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions



# Load dataset (Ensure the file is downloaded from Kaggle: https://www.kaggle.com/tboyle10/medicaltranscriptions)
df = pd.read_csv("mtsamples.csv")

# Select relevant columns (assuming 'description' as text and 'medical_specialty' as labels)
df = df[['transcription', 'medical_specialty']].dropna()

# Reduce the number of categories for a binary classification task
df['LABEL'] = df['medical_specialty'].apply(lambda x: 1 if x == ' Surgery' else 0)

# Splitting dataset into training and validation sets
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['transcription'].values, df['LABEL'].values, test_size=0.2, random_state=42
)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

class MedicalDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer.encode_plus(
            self.texts[idx],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# Define dataset parameters
max_len = 512
batch_size = 256

train_dataset = MedicalDataset(train_texts, train_labels, tokenizer, max_len)
val_dataset = MedicalDataset(val_texts, val_labels, tokenizer, max_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)


In [4]:
df['LABEL'].value_counts()

LABEL
0    3878
1    1088
Name: count, dtype: int64

In [5]:
len(train_texts)

3972

In [6]:
len(val_texts)

994

In [7]:
next(iter(train_loader))

{'input_ids': tensor([[  101,  1055,  1011,  ...,     0,     0,     0],
         [  101,  4241, 19386,  ...,     0,     0,     0],
         [  101,  2381,  1997,  ...,  7175,  4030,   102],
         ...,
         [  101,  3653, 25918,  ..., 10814, 10440,   102],
         [  101,  3653, 25918,  ...,  2059,  2741,   102],
         [  101,  3114,  2005,  ...,  8995,  5751,   102]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'labels': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [8]:
curr_max = 0
for i in train_loader:
    temp_max = i['input_ids'].max().detach().cpu().numpy()
    if temp_max > curr_max:
        curr_max = temp_max

In [9]:
for i in val_loader:
    temp_max = i['input_ids'].max().detach().cpu().numpy()
    if temp_max > curr_max:
        curr_max = temp_max

In [10]:
vocab_size = curr_max + 1

In [11]:
class PositionalEncoding(nn.Module):
    # Cool thread to visualize it: https://datascience.stackexchange.com/questions/51065/what-is-the-positional-encoding-in-the-transformer-model
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model, device=device)
        position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            pe[:, 1::2] = torch.cos(position * div_term[:pe[:, 1::2].shape[1]])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  
        self.register_buffer('pe', pe)
    #Layer that adds the encoding.
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [12]:
class FlashAttentionLayer(nn.Module):
    # Quick flash attention implementation
    def __init__(self, embed_dim, nhead, dropout=0.1):
        super().__init__()
        assert embed_dim % nhead == 0, "Embedding dimension must be divisible by number of heads"
        self.nhead = nhead
        self.head_dim = embed_dim // nhead  # Get head dimensions
        
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn_proj_q = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_k = nn.Linear(embed_dim, embed_dim)
        self.attn_proj_v = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, L, C = x.shape  # Batch, Sequence Length, Embedding Dim
        # Project linearly for Q, K, V
        q = self.attn_proj_q(x).view(B, L, self.nhead, self.head_dim).transpose(1, 2)  # (B, nh, L, head_dim)
        k = self.attn_proj_k(x).view(B, L, self.nhead, self.head_dim).transpose(1, 2)
        v = self.attn_proj_v(x).view(B, L, self.nhead, self.head_dim).transpose(1, 2)

        # Apply Flash Attention
        attn_output = memory_efficient_attention(q, k, v)  # (B, nh, L, head_dim)

        attn_output = attn_output.transpose(1, 2).reshape(B, L, C)  # Reshape back to (B, L, C)
        attn_output = self.out_proj(attn_output)
        x = self.norm1(x + self.dropout(attn_output))

        ff_output = self.feedforward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


In [13]:

#Important to note the many parameters here, that's what we're optimizing
class SmallFlashTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, nhead, num_layers, dropout, num_classes=2):
        super().__init__()
        #Make embedding and encoding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim)
        # Note: You can stack as many attention layers as you want
        self.layers = nn.ModuleList([
            FlashAttentionLayer(embed_dim, nhead, dropout) for _ in range(num_layers)
        ])
        # Linear head
        # Add extra layers for better feature extraction
        self.norm = nn.LayerNorm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, embed_dim * 2)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(embed_dim * 2, num_classes)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()
        # Note how init weights is applied
        self.apply(self._init_weights)
    # Note the model initialization and weights
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
            
    def forward(self, x, attention_mask=None):
        x = self.embedding(x)
        x = self.pos_encoder(x)
        for layer in self.layers:
            x = layer(x)
        
        if attention_mask is not None:
            attention_mask = attention_mask.float()
        # Note the attention mask that is used to only pay attention "backwards" through the text.
        #Not necessary for nucleic acids
        if attention_mask is not None:
            x = x * attention_mask.unsqueeze(-1)
            # Mean pooling with mask - Note the representation used - Mean token representation here:
            #x = x.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
            # Last token representation version
            x = x[:,-1,:] / attention_mask.sum(dim=1, keepdim=True).clamp(min=1e-9)
        else:
            #x = x.mean(dim=1)
            x = x[:,-1,:]
        # Note the Fully Connected head built on top of the attention layers
        x = self.norm(x)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [14]:

# You need to generate a training function, it's also good practice to do so.
def train_epoch(model, dataloader, optimizer, scheduler, criterion):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        inputs = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs, attention_mask)
        # Debugging output shapes
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item() * inputs.size(0)
        scheduler.step()
    return total_loss / len(dataloader.dataset)

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(inputs)
            # Debugging output shapes
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
    avg_loss = total_loss / len(dataloader.dataset)
    accuracy = correct / len(dataloader.dataset)
    return avg_loss, accuracy





In [None]:


# For optuna, it's good to generate a function that runs all your suggestions through.
def objective(trial):
    #This is what we'll test, but it could be anything, HEHE!
    embed_dim = trial.suggest_categorical("embed_dim", [32, 64, 128])
    nhead = trial.suggest_categorical("nhead", [1, 2, 4])
    num_layers = trial.suggest_int("num_layers", 1, 2)
    dropout = trial.suggest_float("dropout", 0.1, 0.5)
    lr = trial.suggest_float("lr", 1e-4, 1e-2)
    batch_size = trial.suggest_categorical("batch_size", [128,256,512])
    # Make the transformer
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    model = SmallFlashTransformer(vocab_size, embed_dim, nhead, num_layers, dropout).to(device)
    #NOTE this: Weighted classes for CrossEntropy loss
    class_counts = df['LABEL'].value_counts().to_list()
    weights = [1.0 / count for count in class_counts]
    class_weights = torch.tensor(weights, dtype=torch.float).to(device)
    
    #Note the weight and label smoothing
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    #Note the difference between Adam and Adam Weight Decay
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    #Note difference between StepLR and OneCycleLR
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1) 
    scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            epochs=5,
            steps_per_epoch=len(train_loader),
            pct_start=0.1
        )
    # Generate data loaders
    # Set how many epocs you want
    num_epochs = 10
    for epoch in range(num_epochs):
        print(epoch)
        train_loss = train_epoch(model, train_loader, optimizer, scheduler,criterion)
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        print(val_acc)
        trial.report(val_loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return val_loss
# Do you want to minimize or maximize the objective (val_loss?)
study = optuna.create_study(direction="minimize")
# How many trials? optuna has several optimization algoirthms included
study.optimize(objective, n_trials=20, timeout=600)
#https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/003_efficient_optimization_algorithms.html Oh god, there's so many of them.
print("Best trial:")
trial = study.best_trial
print("  Validation Loss:", trial.value)
print("  Best hyperparameters:", trial.params)

[I 2025-03-11 13:42:12,831] A new study created in memory with name: no-name-828eca21-7be0-49c4-9261-22bee44cacbb


0
0.23138832997987926
1
