In [1]:
import torch 
import torch.nn as nn
# from models import SegFormer_CS444
import transformers
from dataset import ContrailsDataset
%load_ext autoreload
%autoreload 2

In [2]:
IMAGE_SIZE = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Encoder weights come from the pretrained MiT-B3 checkpoint. The stock decode head
# is newly initialised (see the checkpoint-loading warning) and is replaced below.
model = transformers.SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b3", num_labels=2, image_size=IMAGE_SIZE
).to(device)
# Overwrite the Segformer head with our own modified head to use some new tricks;
# this head is trained from scratch.
from models import SegformerDecodeHeadModified
model.decode_head = SegformerDecodeHeadModified(model.config).to(device)
model.train()

# Module prefixes whose parameters are fine-tuned; every other parameter is frozen.
unfreeze_layers = ['segformer.encoder.patch_embeddings', 'segformer.encoder.block.2', 'segformer.encoder.block.3', 'segformer.encoder.layer_norm', 'decode_head']


def _in_modules(param_name, prefixes):
    """Return True if ``param_name`` is, or lives under, one of the module ``prefixes``.

    Matching is on whole dotted components anchored at the start of the name, so
    ``'encoder.block.2'`` matches ``'encoder.block.2.0.attn.weight'`` but not
    ``'encoder.block.20.weight'`` (a plain substring test would match both).
    """
    return any(param_name == p or param_name.startswith(p + '.') for p in prefixes)


named_params = list(model.named_parameters())
# Single pass: each parameter is trainable iff it belongs to an unfrozen module.
for name, param in named_params:
    param.requires_grad = _in_modules(name, unfreeze_layers)

# Catch typos in unfreeze_layers: a prefix that matches nothing would otherwise
# silently leave that part of the network frozen.
unmatched = [p for p in unfreeze_layers if not any(_in_modules(n, [p]) for n, _ in named_params)]
assert not unmatched, f"unfreeze_layers entries matched no parameters: {unmatched}"

  return self.fget.__get__(instance, owner)()
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
import evaluate

# Mean intersection-over-union metric, fetched from the Hugging Face `evaluate` hub.
metric = evaluate.load(path="mean_iou")

In [4]:
import lightning as L
import torch.optim as optim
import os

# Let float32 matmuls use reduced-precision internal math (e.g. TF32/bfloat16)
# on supported hardware, trading a little accuracy for speed.
torch.set_float32_matmul_precision('medium')
# Root of the contrails data. Override with the CONTRAILS_DATA_DIR environment
# variable instead of editing the notebook; defaults to the original location.
DATA_DIR = os.environ.get("CONTRAILS_DATA_DIR", "/data/contrails")
train_dataset = ContrailsDataset(os.path.join(DATA_DIR, "train"))
val_dataset = ContrailsDataset(os.path.join(DATA_DIR, "validation"))
class LitSegDeg(L.LightningModule):
    """Lightning wrapper that trains and validates a Hugging Face Segformer model.

    The wrapped ``model`` is called as ``model(pixel_values=x, labels=y)`` and must
    return an object exposing ``.loss`` and ``.logits``. Data comes from the
    notebook-level ``train_dataset`` / ``val_dataset``; mean IoU uses the
    notebook-level ``metric``.

    Args:
        model: segmentation model to optimise.
        lr: AdamW learning rate.
        batch_size: batch size for both the train and validation loaders.
    """

    # Mean IoU is only computed on every N-th batch to keep steps cheap.
    IOU_EVERY_N_BATCHES = 50

    def __init__(self, model, lr=1e-4, batch_size=32):
        super().__init__()
        self.model = model
        self.lr = lr
        self.batch_size = batch_size

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader with the worker/prefetch settings shared by both splits."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=4,
            persistent_workers=True,
            prefetch_factor=8,
        )

    def train_dataloader(self):
        # DataLoader defaults to shuffle=False; training needs a fresh sample
        # order every epoch.
        return self._make_dataloader(train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(val_dataset, shuffle=False)

    def _batch_mean_iou(self, logits, labels):
        """Return the mean IoU of ``logits`` against ``labels`` for this batch only."""
        with torch.no_grad():
            # Resize logits to the label resolution before taking the per-pixel argmax.
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)
        # Compute on this batch directly. metric.add_batch() is intentionally NOT
        # called: nothing ever calls metric.compute() to consume that buffer, so it
        # would just keep growing for the whole run.
        results = metric._compute(
            predictions=predicted.detach().cpu().numpy(),
            references=labels.detach().cpu().numpy(),
            num_labels=getattr(getattr(self.model, "config", None), "num_labels", 2),
            ignore_index=255,
            reduce_labels=False,  # we've already reduced the labels ourselves
        )
        return results["mean_iou"]

    def _shared_step(self, batch, batch_idx, stage):
        """Run one batch and log ``{stage}_loss``; log ``{stage}_mean_iou`` every N batches.

        Returns the model's loss for this batch.
        """
        pixel_values, labels = batch
        outputs = self.model(pixel_values=pixel_values, labels=labels)
        self.log(f"{stage}_loss", outputs.loss)
        if batch_idx % self.IOU_EVERY_N_BATCHES == 0:
            self.log(f"{stage}_mean_iou", self._batch_mean_iou(outputs.logits, labels))
        return outputs.loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "val")

    def configure_optimizers(self):
        # Frozen parameters never receive gradients, so AdamW leaves them untouched.
        return optim.AdamW(self.parameters(), lr=self.lr)

In [5]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

print("Starting model training...")
SEED = 42
# Seed Python/NumPy/torch RNGs (and dataloader workers) so data shuffling, dropout
# and any random augmentation are reproducible across runs.
L.seed_everything(SEED, workers=True)

# Keep the two best checkpoints, ranked by validation loss.
checkpoint_callback = ModelCheckpoint(save_top_k=2, monitor="val_loss")
# Stop after 3 consecutive validation checks without val_loss improvement.
# Validation runs every `val_check_interval` training batches, not once per epoch.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode="min",
)

l_model = LitSegDeg(model, batch_size=16)
trainer = L.Trainer(
    max_epochs=20,
    log_every_n_steps=5,
    callbacks=[checkpoint_callback, early_stop_callback],
    val_check_interval=400,
)
trainer.fit(model=l_model)

Starting model training...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/dsingh/source/devksingh4/transfer-vit-unet/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                             | Params
-----------------------------------------------------------
0 | model | SegformerForSemanticSegmentation | 50.4 M
-----------------------------------------------------------
47.6 M    Trainable params
2.8 M     Non-trainable params
50.4 M    Total params
201.516   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]