In [1]:
import os
def change_to_code_dir():
    current_dir = os.getcwd()
    if os.path.basename(current_dir) == "p2ch15":
        parent_dir = os.path.dirname(current_dir)
        os.chdir(parent_dir)    
change_to_code_dir()

In [2]:
import torch
from PIL import Image
from p2ch15.utils import FineTuningDataset, plot_tensor_histogram, plot_mask
from torch.utils.data import DataLoader, Dataset
from transformers import SegformerImageProcessor
from datetime import datetime
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

fine_tuning_dir = "data-unversioned/part2/fine-tuning/dataset"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
from p2ch15.utils import FineTuningDataset
train_dataset = FineTuningDataset(split="train")
val_dataset = FineTuningDataset(split="val")
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8)

In [4]:
from transformers import SegformerForSemanticSegmentation

id2label = {"0": "background", "1": "nodule"}
label2id = {v: k for k, v in id2label.items()}

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
model.to(device)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [5]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)

In [6]:
image_processor = SegformerImageProcessor()
def encode_inputs_for_model(image_paths, masks_paths=[]):
    images = [Image.open(path) for path in image_paths]
    masks = [Image.open(path) for path in masks_paths] or None
    encoded_inputs = image_processor(images, masks, return_tensors="pt")
    return encoded_inputs["pixel_values"].to(device), encoded_inputs["labels"].to(device)

# example usage
for batch in train_dataloader:
    break

image_paths, masks_paths = batch["ct_image_path"], batch["mask_image_path"]
pixel_values, labels = encode_inputs_for_model(image_paths, masks_paths)
print(pixel_values.shape, labels.shape)

torch.Size([8, 3, 512, 512]) torch.Size([8, 512, 512])


In [8]:
from torch.utils.tensorboard import SummaryWriter
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"runs/p2ch15/fine_tune_{timestamp}"
writer = SummaryWriter(log_dir=log_dir)

num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    print("Epoch:", epoch)
    
    # Train
    total_train_loss = 0
    num_train_batches = 0
    for idx, batch in enumerate(tqdm(train_dataloader)):
        pixel_values, labels = encode_inputs_for_model(batch["ct_image_path"], batch["mask_image_path"])
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        num_train_batches += 1

    # Validation
    model.eval()
    total_val_loss = 0
    num_val_batches = 0
    with torch.no_grad():
        for batch in val_dataloader:
            pixel_values, labels = encode_inputs_for_model(batch["ct_image_path"], batch["mask_image_path"])
            outputs = model(pixel_values=pixel_values, labels=labels)
            val_loss = outputs.loss
            total_val_loss += val_loss.item()
            num_val_batches += 1

    # Calculate average train and validation loss
    average_train_loss = total_train_loss / num_train_batches
    average_val_loss = total_val_loss / num_val_batches
    print(f"Training Loss: {average_train_loss:.4f}, Validation Loss: {average_val_loss:.4f}")
    # Log the losses to TensorBoard
    writer.add_scalar('Loss/Train', average_train_loss, epoch)
    writer.add_scalar('Loss/Validation', average_val_loss, epoch)
writer.close()

Epoch: 0


  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:26<00:00,  1.89it/s]


Training Loss: 0.0214, Validation Loss: 0.0174
Epoch: 1


100%|██████████| 50/50 [00:24<00:00,  2.05it/s]


Training Loss: 0.0194, Validation Loss: 0.0161
Epoch: 2


100%|██████████| 50/50 [00:27<00:00,  1.83it/s]


Training Loss: 0.0171, Validation Loss: 0.0149
Epoch: 3


100%|██████████| 50/50 [00:25<00:00,  1.93it/s]


Training Loss: 0.0156, Validation Loss: 0.0129
Epoch: 4


100%|██████████| 50/50 [00:25<00:00,  1.94it/s]


Training Loss: 0.0139, Validation Loss: 0.0118
Epoch: 5


100%|██████████| 50/50 [00:25<00:00,  1.97it/s]


Training Loss: 0.0128, Validation Loss: 0.0112
Epoch: 6


100%|██████████| 50/50 [00:26<00:00,  1.92it/s]


Training Loss: 0.0115, Validation Loss: 0.0101
Epoch: 7


100%|██████████| 50/50 [00:25<00:00,  1.96it/s]


Training Loss: 0.0107, Validation Loss: 0.0085
Epoch: 8


100%|██████████| 50/50 [00:25<00:00,  1.98it/s]


Training Loss: 0.0099, Validation Loss: 0.0085
Epoch: 9


100%|██████████| 50/50 [00:25<00:00,  1.99it/s]


Training Loss: 0.0092, Validation Loss: 0.0080
Epoch: 10


100%|██████████| 50/50 [00:24<00:00,  2.01it/s]


Training Loss: 0.0085, Validation Loss: 0.0077
Epoch: 11


100%|██████████| 50/50 [00:25<00:00,  1.99it/s]


Training Loss: 0.0079, Validation Loss: 0.0071
Epoch: 12


100%|██████████| 50/50 [00:27<00:00,  1.82it/s]


Training Loss: 0.0075, Validation Loss: 0.0063
Epoch: 13


100%|██████████| 50/50 [00:25<00:00,  1.94it/s]


Training Loss: 0.0070, Validation Loss: 0.0062
Epoch: 14


100%|██████████| 50/50 [00:26<00:00,  1.89it/s]


Training Loss: 0.0065, Validation Loss: 0.0057
Epoch: 15


100%|██████████| 50/50 [00:26<00:00,  1.87it/s]


Training Loss: 0.0062, Validation Loss: 0.0052
Epoch: 16


100%|██████████| 50/50 [00:26<00:00,  1.92it/s]


Training Loss: 0.0058, Validation Loss: 0.0048
Epoch: 17


100%|██████████| 50/50 [00:26<00:00,  1.91it/s]


Training Loss: 0.0055, Validation Loss: 0.0045
Epoch: 18


100%|██████████| 50/50 [00:26<00:00,  1.92it/s]


Training Loss: 0.0052, Validation Loss: 0.0048
Epoch: 19


100%|██████████| 50/50 [00:25<00:00,  1.97it/s]


Training Loss: 0.0049, Validation Loss: 0.0042


In [9]:
torch.save(model.state_dict(), "p2ch15/segformer_epoch_20.pt")

In [None]:
new_model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
state_dict = torch.load("p2ch15/segformer_epoch_20.pt")
new_model.load_state_dict(state_dict)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  state_dict = torch.load("p2ch15/segformer_epoch_10.pt")


<All keys matched successfully>