<a href="https://colab.research.google.com/github/kyileiaye2021/HistoGPT-Teacher-Student-Network/blob/main/HistoGPT_Teacher_Student_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup environment

In [1]:
# install openslide dependencies
!sudo apt-get install openslide-tools
!sudo apt-get install python-openslide
!pip install openslide-python

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
openslide-tools is already the newest version (3.4.1+dfsg-5build1).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
E: Unable to locate package python-openslide


In [2]:
# install flamingo and histogpt
!pip install flamingo-pytorch --no-deps
!pip install git+https://github.com/marrlab/HistoGPT.git

Collecting git+https://github.com/marrlab/HistoGPT.git
  Cloning https://github.com/marrlab/HistoGPT.git to /tmp/pip-req-build-spq7nhzw
  Running command git clone --filter=blob:none --quiet https://github.com/marrlab/HistoGPT.git /tmp/pip-req-build-spq7nhzw
  Resolved https://github.com/marrlab/HistoGPT.git to commit 35feddc2b5833676e9e8f09ee432b548a2a75e46
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [5]:
# check whether to use a gpu or cpu
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


# Mounting the Google Drive

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Load the HistoGPT Teacher model (original model)

In [7]:
from transformers import BioGptConfig
from histogpt.models import HistoGPTForCausalLM, PerceiverResamplerConfig

histogpt_teacher = HistoGPTForCausalLM(BioGptConfig(), PerceiverResamplerConfig())
histogpt_teacher = histogpt_teacher.to(device)
PATH = '/content/drive/MyDrive/Teacher_Student_Network/Histo GPT/Histogpt_weights/histogpt-1b-6k-pruned.pth' # histogpt weight
state_dict = torch.load(PATH, map_location=device)
histogpt_teacher.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# Checking HistoGPT Teacher model parameters

In [8]:
for name, param in histogpt_teacher.named_parameters():
    print(name)


histogpt.perceiver_resampler.media_pos
histogpt.perceiver_resampler.latents
histogpt.perceiver_resampler.linear.weight
histogpt.perceiver_resampler.layers.0.0.norm_media.weight
histogpt.perceiver_resampler.layers.0.0.norm_media.bias
histogpt.perceiver_resampler.layers.0.0.norm_latents.weight
histogpt.perceiver_resampler.layers.0.0.norm_latents.bias
histogpt.perceiver_resampler.layers.0.0.to_q.weight
histogpt.perceiver_resampler.layers.0.0.to_kv.weight
histogpt.perceiver_resampler.layers.0.0.to_out.weight
histogpt.perceiver_resampler.layers.0.1.0.weight
histogpt.perceiver_resampler.layers.0.1.0.bias
histogpt.perceiver_resampler.layers.0.1.1.weight
histogpt.perceiver_resampler.layers.0.1.3.weight
histogpt.perceiver_resampler.layers.1.0.norm_media.weight
histogpt.perceiver_resampler.layers.1.0.norm_media.bias
histogpt.perceiver_resampler.layers.1.0.norm_latents.weight
histogpt.perceiver_resampler.layers.1.0.norm_latents.bias
histogpt.perceiver_resampler.layers.1.0.to_q.weight
histogpt.per

# Freezing All Trainable Parameters in HistoGPT Teacher Model

In [9]:
# Freezing all parameters
for param in histogpt_teacher.parameters():
    param.requires_grad = False


In [10]:
# check if all parameters are freezed
for name, param in histogpt_teacher.named_parameters():
  if param.requires_grad:
    print(name)

# Loading HistoGPT Student Model

In [11]:
from transformers import BioGptConfig
from histogpt.models import HistoGPTForCausalLM, PerceiverResamplerConfig

histogpt_student = HistoGPTForCausalLM(BioGptConfig(), PerceiverResamplerConfig())
histogpt_student = histogpt_student.to(device)
PATH = '/content/drive/MyDrive/Teacher_Student_Network/Histo GPT/Histogpt_weights/histogpt-1b-6k-pruned.pth' # histogpt weight
state_dict = torch.load(PATH, map_location=device)
histogpt_student.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# Unfreezing Certain Parameters in HistoGPT Student Model

In [32]:
# Freezing all parameters first
for name, param in histogpt_student.named_parameters():
    param.requires_grad = False

# Unfreezing the last layer of perceiver sampler (slide level encoder)
for name, param in histogpt_student.named_parameters():
  if 'perceiver_resampler.layers.5' in str(name):
    param.requires_grad = True

# Unfreezing the exist layer of perceiver sampler
for name, param in histogpt_student.named_parameters():
  if 'histogpt.perceiver_exitgate.weight' in name:
    param.requires_grad = True

In [33]:
# check if all parameters are freezed
for name, param in histogpt_student.named_parameters():
  if param.requires_grad:
    print(name)

histogpt.perceiver_resampler.layers.5.0.norm_media.weight
histogpt.perceiver_resampler.layers.5.0.norm_media.bias
histogpt.perceiver_resampler.layers.5.0.norm_latents.weight
histogpt.perceiver_resampler.layers.5.0.norm_latents.bias
histogpt.perceiver_resampler.layers.5.0.to_q.weight
histogpt.perceiver_resampler.layers.5.0.to_kv.weight
histogpt.perceiver_resampler.layers.5.0.to_out.weight
histogpt.perceiver_resampler.layers.5.1.0.weight
histogpt.perceiver_resampler.layers.5.1.0.bias
histogpt.perceiver_resampler.layers.5.1.1.weight
histogpt.perceiver_resampler.layers.5.1.3.weight
histogpt.perceiver_exitgate.weight


### Dataset Class

In [26]:
!pip install timm



In [27]:
from torch.utils.data import Dataset, DataLoader
import torchvision
import os
from PIL import Image
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

class AugmentedImageDataset(Dataset):
  def __init__(self, oct_dir, he_dir):
    self.oct_dir = oct_dir
    self.he_dir = he_dir

    self.paired_files = []
    he_files = os.listdir(he_dir)
    he_files_lower = {he_file.lower(): he_file for he_file in he_files}
    for oct_file in os.listdir(oct_dir):
      oct_file_lower = oct_file.lower()

      # replace '_real_a' with '_real_b' to find the matching files
      if '_real_a' in oct_file_lower:
        he_file = oct_file_lower.replace('_real_a', '_real_b')

      elif '_fake_a' in oct_file_lower:
        he_file = oct_file_lower.replace('_fake_a', '_fake_b')

      if he_file in he_files_lower:
        he_file = he_files_lower[he_file]
        self.paired_files.append((os.path.join(oct_dir, oct_file), os.path.join(he_dir, he_file)))
      else:
        print(f'Warning: {he_file} not found in {he_dir}')


    img_size = 384
    self.transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(img_size, interpolation=3, antialias=True),
        torchvision.transforms.CenterCrop((img_size, img_size)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
    ])

  def __len__(self):
    return len(self.paired_files)

  def __getitem__(self, idx):
    oct_path, he_path = self.paired_files[idx] # retrieve the paired (oct, h&e) file names
    oct_image = Image.open(oct_path).convert('RGB')
    he_image = Image.open(he_path).convert('RGB')

    return self.transform(oct_image), self.transform(he_image)



### Prepare Data Loader

In [28]:
# create config dictionary
train_config = {
    "train_folder": "/content/drive/MyDrive/Teacher_Student_Network/Split Dataset/train",
    "he_val_folder": "/content/drive/MyDrive/Teacher_Student_Network/Split Dataset/he_val",
    "vhe_val_folder": "/content/drive/MyDrive/Teacher_Student_Network/Split Dataset/vhe_val",
    "batch_size": 128, # Reduced batch size
    "epochs": 300,
    "shuffle_train": True,
    "num_workers": 2,
    "pin_memory": True,
    "init_lr": 1e-6,
    "weight_decay": 0.05,
    "epochs_between_save": 5,
    "epochs_between_val": 5,
    "patience": 5,  # for early stopping
    "output_dir_path": "/content/drive/MyDrive/Teacher_Student_Network/model_savepoints"
}

train_oct_folder = os.path.join(train_config['train_folder'],'OCT')
train_he_folder = os.path.join(train_config['train_folder'],'H&E')

val_he_oct_folder = os.path.join(train_config['he_val_folder'], 'OCT')
val_he_folder = os.path.join(train_config['he_val_folder'], 'H&E')
val_vhe_oct_folder = os.path.join(train_config['vhe_val_folder'], 'OCT')
val_vhe_folder = os.path.join(train_config['vhe_val_folder'], 'vH&E')

train_dataset = AugmentedImageDataset(train_oct_folder, train_he_folder)
he_val_dataset = AugmentedImageDataset(val_he_oct_folder, val_he_folder)
vhe_val_dataset = AugmentedImageDataset(val_vhe_oct_folder, val_vhe_folder)

train_loader = DataLoader(
    train_dataset,
    batch_size=train_config['batch_size'],
    shuffle=train_config['shuffle_train'],
    num_workers=train_config['num_workers'],
    pin_memory=train_config['pin_memory']
)

he_val_loader = DataLoader(
    he_val_dataset,
    batch_size=train_config['batch_size'],
    shuffle=False, # usually false for val
    num_workers=train_config['num_workers'],
    pin_memory=train_config['pin_memory']
)

vhe_val_loader = DataLoader(
    vhe_val_dataset,
    batch_size=train_config['batch_size'],
    shuffle=False, # usually false for val
    num_workers=train_config['num_workers'],
    pin_memory=train_config['pin_memory']
)



In [29]:
print(f"Total number of H&E-OCT pairs batches in each epoch in train set: {len(train_loader)}")
print(f"Total number of H&E-OCT pairs batches in each epoch in val set: {len(he_val_loader)}")
print(f"Total number of vH&E-OCT pairs batches in each epoch in val set: {len(vhe_val_loader)}")

Total number of H&E-OCT pairs batches in each epoch in train set: 45
Total number of H&E-OCT pairs batches in each epoch in val set: 5
Total number of vH&E-OCT pairs batches in each epoch in val set: 5


### Implementing Loss Function

In [None]:
# import torch.nn.functional as F
# def cosine_loss(student_emb, teacher_emb):
#   return 1 - F.cosine_similarity(student_emb, teacher_emb, dim=1).mean()
import torch.nn.functional as F
def contrastive_ce(student_emb, teacher_emb, temp = 0.07):
  '''
  Contrastive cross-entropy loss (InfoNCE-style).
  COmpares each student embedding to all teacher embeddings in the batch.
  '''

  # Normalize embeddings
  student_emb = F.normalize(student_emb, dim=-1)
  teacher_emb = F.normalize(teacher_emb, dim=-1)

  # Compute similarity matrix: (batch_size, batch_size)
  logits = torch.matmul(student_emb, teacher_emb.T)
  logits /= temp # apply temp scaling

  # Ground Truth: i-th student <-> i-th teacher
  targets = torch.arange(logits.size(0), device=logits.device) # this tells that H&E correct match of ith OCT is in index i

  # Compute cross entropy loss
  loss = F.cross_entropy(logits, targets)

  return loss


### Loading the saved model checkpoint

In [None]:
def load_checkpoint(student, teacher, optimizer, config, scheduler=None):
  checkpoint_path = os.path.join(config['output_dir_path'], 'larger batch size 128','best_model.pth')

  if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    student.load_state_dict(checkpoint['student_state_dict'])

    if teacher and checkpoint['teacher_state_dict']:
      teacher.load_state_dict(checkpoint['teacher_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Manually override the learning rate from the checkpoint
    for param_group in optimizer.param_groups:
      param_group['lr'] = config['init_lr']  # Or directly use 1e-5
    print(f"Learning rate reset to: {optimizer.param_groups[0]['lr']}")


    if scheduler and checkpoint['scheduler_state_dict']:
      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    best_val_loss = checkpoint.get('val_loss', float('inf'))
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
    print(f"Best val loss: {best_val_loss}")
    return checkpoint['epoch'] + 1, best_val_loss # Resume from next epoch

  else:
    print("No checkpoint found. Starting from scratch.")
    return 0, float('inf')  # Start from epoch 0

### Saving the trained model checkpoint

In [None]:
def save_checkpoint(student, teacher, optimizer, epoch, config, scheduler=None, val_loss=None, filename='latest_checkpoint.pth'):
  print('Saving to: ', config['output_dir_path'])
  os.makedirs(config['output_dir_path'], exist_ok=True)

  best_model_path = os.path.join(config['output_dir_path'], 'larger batch size 128', 'best_model_2.pth')
  torch.save({
      'epoch': epoch,
      'student_state_dict': student.state_dict(),
      'teacher_state_dict': teacher.state_dict() if teacher else None,
      'optimizer_state_dict': optimizer.state_dict(),
      'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
      'val_loss': val_loss
  }, best_model_path)

  print(f"Saved checkpoint at epoch {epoch} with val_loss: {val_loss}")

### Training the model

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Filter parameters that require gradients
# trainable_params = [p for p in student_model.parameters() if p.requires_grad]
optimizer = Adam(student_model.parameters(), lr=train_config['init_lr'], weight_decay=train_config['weight_decay'])
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',  # minimize val loss
    factor=0.5,  # reduce LR by half
    patience=1,  # wait 1 epochs of no improvement
    threshold=1e-4, # minimum change to qualify as improvement
    verbose=True # print LR changes
)
print(optimizer.param_groups[0]['lr'])
print(optimizer.param_groups[0]['weight_decay'])

In [None]:
from torch.utils.tensorboard import SummaryWriter # for visualizing the loss

writer = SummaryWriter(log_dir = 'runs/student_training_run')
from torch.cuda.amp import autocast, GradScaler
def train_student(num_epochs, train_loader, he_val_loader, vhe_val_loader, student_model, teacher_model, optimizer, device, scheduler):

  scaler = GradScaler()  # Needed for mixed precision

  # for early stopping
  epochs_without_improvement = 0
  patience = train_config['patience']

  # Load from checkpoint if available
  start_epoch, best_val_loss = load_checkpoint(student_model, teacher_model, optimizer, train_config, scheduler)
  # best_val_loss = float('inf')
  # start_epoch = 0

  ### Training Loop #####
  for epoch in range(start_epoch, num_epochs):
    student_model.train()  # student in train mode
    total_loss = 0

    for batch_idx, (oct_images, he_images) in enumerate(train_loader):
      oct_images = oct_images.to(device)
      he_images = he_images.to(device)

      with torch.no_grad():
        teacher_emb = teacher_model(image=he_images, with_head=True, out_norm=True, ms_aug=False, return_global=True)[0]
      teacher_emb = teacher_emb.detach()

      with autocast(dtype=torch.float16):
        student_emb = student_model(image=oct_images, with_head=True, out_norm=True, ms_aug=False, return_global=True)[0]
        loss = contrastive_ce(student_emb, teacher_emb)
        # print(f"Student emb mean: {student_emb.mean().item():.4f}, Teacher emb mean: {teacher_emb.mean().item():.4f}, Loss: {loss.item():.6f}")

        if torch.isnan(student_emb).any() or torch.isnan(teacher_emb).any():
          print(f"NaNs detected at batch {batch_idx+1}")

        if student_emb.abs().max() < 1e-6:
          print(f"Student embedding is close to zero at batch {batch_idx+1}")

      optimizer.zero_grad() # zero out gradient
      scaler.scale(loss).backward() # backpropagation
      scaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
      scaler.step(optimizer) # update parameters (weight & biases)
      scaler.update()

      total_loss += loss.item()

      if (batch_idx + 1) % 15 == 0: # print the loss in every 100 batch in each epoch
        print(f"Epoch [{epoch + 1}/{num_epochs}] | Batch [{batch_idx + 1}/{len(train_loader)}] | Batch Loss: {loss.item():.4f}")

      # log the loss every 100 batches in each epoch
      global_step = epoch * len(train_loader) + batch_idx
      writer.add_scalar('Loss/train', loss.item(), global_step)

    ### Training Loop ####

    ### Printing average loss after each epoch ####
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {average_loss:.4f}")
    # log average loss at the end of each epoch
    writer.add_scalar("Loss/epoch_train", average_loss, epoch)
    # log learning rate to tensorboard
    writer.add_scalar("Learning Rate", optimizer.param_groups[0]['lr'], epoch)
    ### Printing average loss after each epoch ####

    torch.cuda.empty_cache()  # Clean memory after each epoch

    ### Evaluate every 10 epochs with validation dataset ####
    if (epoch + 1) % train_config['epochs_between_val'] == 0:
      print(f"Evaluating at epoch {epoch + 1}")
      student_model.eval() # evaluation mode
      he_val_loss_total = 0
      vhe_val_loss_total = 0

      with torch.no_grad():
        for oct_images, he_images in he_val_loader:
          oct_images = oct_images.to(device)
          he_images = he_images.to(device)

          teacher_emb = teacher_model(image=he_images, with_head=True, out_norm=True, ms_aug=False, return_global=True)[0]
          student_emb = student_model(image=oct_images, with_head=True, out_norm=True, ms_aug=False, return_global=True)[0]

          val_loss = contrastive_ce(student_emb, teacher_emb)
          he_val_loss_total += val_loss.item()

      with torch.no_grad():
        for oct_images, vhe_images in vhe_val_loader:
          oct_images = oct_images.to(device)
          vhe_images = vhe_images.to(device)

          teacher_emb = teacher_model(image=vhe_images, with_head=True, out_norm=True, ms_aug=False, return_global=True)[0]
          student_emb = student_model(image=oct_images, with_head=True, out_norm=True, ms_aug=False, return_global=True)[0]

          val_loss = contrastive_ce(student_emb, teacher_emb)
          vhe_val_loss_total += val_loss.item()

      avg_he_val_loss = he_val_loss_total / len(he_val_loader)
      avg_vhe_val_loss = vhe_val_loss_total / len(vhe_val_loader)

      print(f"Validation Loss (H&E) at Epoch {epoch + 1}: {avg_he_val_loss:.4f}")
      print(f"Validation Loss (vH&E) at Epoch {epoch + 1}: {avg_vhe_val_loss:.4f}")

      avg_val_loss = (avg_he_val_loss + avg_vhe_val_loss) / 2
      print(f"Average Validation Loss at Epoch {epoch + 1}: {avg_val_loss:.4f}")


      scheduler.step(avg_val_loss) # update the learning rate
      current_lr = optimizer.param_groups[0]['lr']
      print(f"Learning Rate after scheduler step: {current_lr}")

      torch.cuda.empty_cache()  # Clean memory after each epoch

      if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss # update the curr avg val loss if the avg val loss is less than the prev avg val loss
        print(f'New best val loss: {best_val_loss}')
        epochs_without_improvement = 0
        save_checkpoint(student_model, teacher_model, optimizer, epoch, train_config, scheduler, best_val_loss, filename=f"best_model.pth")

      else:
        epochs_without_improvement += 1

      if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch + 1}") # to prevent overfitting
        break

writer.close()

In [None]:
train_student(
    num_epochs=train_config['epochs'],
    train_loader=train_loader,
    he_val_loader=he_val_loader,
    vhe_val_loader=vhe_val_loader,
    student_model=histogpt_student,
    teacher_model=histogpt_teacher,
    optimizer=optimizer,
    device=device,
    scheduler=scheduler
  )