In [1]:
from custom_dataset.segdataset import InstanceSegmentationDataset
from torch.utils.data import DataLoader
import json

import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import os
import evaluate
import logging
from utils import rs_utils
from datetime import datetime
from transformers import SegformerImageProcessor

In [2]:
#--- root 
root_dir = '/disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe'

#--- time
current_time = datetime.now()
formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")

In [3]:
#-----------
_VERSION = "005"
_EPOCHS = 50
_MODEL_SAVE = True
_MODEL_VERSION = "nvidia/mit-b5"
_MODEL_VERSION_SAVE = _MODEL_VERSION.split("/")[-1]
_BATCH_SIZE = 2
_DEVICE = "cuda:1"
#----------

In [4]:
# Set up a logger
log_dir = "/home/eric/srcs/FewShotSeg_Lab/FewShotVision_Lab/Segmentation_Pipes/logs"
os.makedirs(log_dir, exist_ok=True)
current_time = datetime.now()
formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
log_file_path = os.path.join(log_dir, f"Version_{_VERSION}_{_MODEL_VERSION_SAVE}_{formatted_time}.log")
logger = rs_utils.setup_logger("FewShotSeg", log_file_path, level=logging.INFO)

In [5]:
json_file = '/disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/_annotations.coco.json'
with open(json_file, 'r') as f:
    data = json.load(f)
#--- 
id2label = {
    0:data['categories'][0]['name'],
    1:data['categories'][1]['name'],
    2:data['categories'][2]['name'],
    3:data['categories'][3]['name'],
    4:data['categories'][4]['name']
}
id2label[0] = "background"
label2id = {v: k for k, v in id2label.items()}

In [6]:
label2id

{'background': 0,
 'M2A1Slammer': 1,
 'M5SandstormMLRS': 2,
 'T140Angara': 3,
 'ZamakMRL': 4}

In [7]:
#----------------------
# reduce_labels should be False !! 

image_processor = SegformerImageProcessor(reduce_labels=False,size={"height": 448, "width": 448})

train_dataset = InstanceSegmentationDataset(root_dir=root_dir, image_processor=image_processor)
valid_dataset = InstanceSegmentationDataset(root_dir=root_dir, image_processor=image_processor, train=False)

train_dataloader = DataLoader(train_dataset, batch_size=_BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=_BATCH_SIZE)



In [8]:
# define model
from models.custom_segformer import FewShotFormer

seg_model = FewShotFormer.from_pretrained(_MODEL_VERSION,
                                        num_labels=5,
                                        id2label=id2label,
                                        label2id=label2id,
                                        force_download=False)

  return self.fget.__get__(instance, owner)()
Some weights of FewShotFormer were not initialized from the model checkpoint at nvidia/mit-b5 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
seg_model.config.hidden_sizes.append(1024)

In [10]:
seg_model.config.hidden_sizes

[64, 128, 320, 512, 1024]

In [11]:
seg_model.config.num_encoder_blocks = 5

In [12]:
seg_model.config.num_encoder_blocks

5

In [13]:
BACKBONE_SIZE = "large" # in ("small", "base", "large" or "giant")

backbone_archs = {
    "small": "vits14",
    "base": "vitb14",
    "large": "vitl14",
    "giant": "vitg14",
}
backbone_arch = backbone_archs[BACKBONE_SIZE]
backbone_name = f"dinov2_{backbone_arch}"

dinov2_model = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=backbone_name)
#backbone_model.to(_DEVICE)

Using cache found in /home/eric/.cache/torch/hub/facebookresearch_dinov2_main


In [14]:
dinov2_model

DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-23): 24 x NestedTensorBlock(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )


In [15]:
# model freeze 
for param in dinov2_model.parameters():
    param.requires_grad = False

In [16]:
a1 = train_dataset.__getitem__(0)['pixel_values']
a1 = a1.unsqueeze(0)
#a1 = a1.to("cuda:0")

filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/M2A1Slammer2__part_1.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/labels/M2A1Slammer2__part_1.png


In [17]:
dino_hidden_states = dinov2_model.forward_features(a1)

In [18]:
dino_hidden_states.keys()

dict_keys(['x_norm_clstoken', 'x_norm_regtokens', 'x_norm_patchtokens', 'x_prenorm', 'masks'])

In [19]:
dino_hidden_states["x_norm_patchtokens"].shape

torch.Size([1, 1024, 1024])

In [20]:
patche_token = dino_hidden_states["x_norm_patchtokens"]

In [21]:
patche_token.reshape([1,-1,32,32]).shape

torch.Size([1, 1024, 32, 32])

In [22]:
patche_token.shape

torch.Size([1, 1024, 1024])

In [23]:
# define metric
metric = evaluate.load("mean_iou")
# define optimizer
optimizer = torch.optim.AdamW(seg_model.parameters(), lr=0.00006)

In [None]:
best_val_iou = 0 


#-- model to device
seg_model = seg_model.to(_DEVICE)
dinov2_model = dinov2_model.to(_DEVICE) 

#---
for epoch in range(_EPOCHS):  # loop over the dataset multiple times
    logger.info(f"Epoch: {epoch}")

    # Training loop
    for idx, batch in enumerate(train_dataloader):
        # get the inputs
        pixel_values = batch["pixel_values"].to(_DEVICE)
        labels = batch["labels"].to(_DEVICE)

        #-----
        patch_token = dinov2_model.forward_features(pixel_values)
        patch_token = patch_token["x_norm_patchtokens"]
        patch_token = patch_token.reshape([pixel_values.shape[0],-1,32,32])

        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = seg_model(pixel_values=pixel_values, labels=labels,dino_features = patch_token)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        # Evaluate periodically
        if idx % 100 == 0:
            with torch.no_grad():
                upsampled_logits = nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                predicted = upsampled_logits.argmax(dim=1)

                metrics = metric._compute(
                    predictions=predicted.cpu(),
                    references=labels.cpu(),
                    num_labels=len(id2label),
                    ignore_index=255,
                    reduce_labels=False,  # we've already reduced the labels ourselves
                )
                logger.info(
                    f"Epoch: {epoch}, "
                    f"Training Loss: {loss.item():.4f}, "
                    f"Mean IoU: {metrics['mean_iou']:.4f}, "
                    f"Mean Accuracy: {metrics['mean_accuracy']:.4f}"
                )
    # Validation loop
    val_loss = 0
    val_metrics = {"mean_iou": 0, "mean_accuracy": 0}
    for batch in valid_dataloader:
        pixel_values = batch["pixel_values"].to(_DEVICE)
        labels = batch["labels"].to(_DEVICE)

        with torch.no_grad():
            #-----
            patch_token = dinov2_model.forward_features(pixel_values)
            patch_token = patch_token["x_norm_patchtokens"]
            patch_token = patch_token.reshape([pixel_values.shape[0],-1,32,32])

            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = seg_model(pixel_values=pixel_values, labels=labels,dino_features = patch_token)
            loss, logits = outputs.loss, outputs.logits
                
            #--------------            
            val_loss += outputs.loss.item()

            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)

            batch_metrics = metric._compute(
                predictions=predicted.cpu(),
                references=labels.cpu(),
                num_labels=len(id2label),
                ignore_index=255,
                reduce_labels=False,
            )
            val_metrics["mean_iou"] += batch_metrics["mean_iou"]
            val_metrics["mean_accuracy"] += batch_metrics["mean_accuracy"]

    # Average validation loss and metrics
    val_loss /= len(valid_dataloader)
    val_metrics["mean_iou"] /= len(valid_dataloader)
    val_metrics["mean_accuracy"] /= len(valid_dataloader)

    logger.info(
        f"Epoch : {epoch}, "
        f"Validation Results - Loss: {val_loss:.4f}, "
        f"Mean IoU: {val_metrics['mean_iou']:.4f}, "
        f"Mean Accuracy: {val_metrics['mean_accuracy']:.4f}"
    )

    # Save the model if the validation IoU improves
    if val_metrics["mean_iou"] > best_val_iou:
        best_val_iou = val_metrics["mean_iou"]
        
        if _MODEL_SAVE:
            torch.save(
                seg_model.state_dict(),
                os.path.join(
                    "/disk3/eric/checkpoints/military_fewshot_seg",
                    f"{_VERSION}_{_MODEL_VERSION_SAVE}_segformer_best_epoch_{epoch}_miou_{best_val_iou:.4f}.pt",
                ),
            )
        logger.info("Model saved!")


2024-11-25 16:47:46 - FewShotSeg - INFO - Epoch: 0


filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/Zamak__part_7.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/labels/Zamak__part_7.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/M5Sandstorm__part_9.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/labels/M5Sandstorm__part_9.png


  acc = total_area_intersect / total_area_label
2024-11-25 16:47:47 - FewShotSeg - INFO - Epoch: 0, Training Loss: 1.5716, Mean IoU: 0.0699, Mean Accuracy: 0.1416


filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/T140Angara__part_5.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/labels/T140Angara__part_5.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/T140Angara__part_1.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/labels/T140Angara__part_1.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/T140Angara__part_21.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/labels/T140Angara__part_21.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/T140Angara__part_20.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/labels/T140Angara__part_20.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmentation_pipe/train/images/M5Sandstorm__part_17.png
filename /disk3/eric/dataset/VISION_SOFS/WEAPON_5/segmenta

AttributeError: 'NoneType' object has no attribute 'shape'