In [1]:
import pathlib
import torch
import torch.utils.data
from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2
import numpy as np
from matplotlib import pyplot as plt


USE_GPU = True
dtype = torch.float32

device = torch.device('cuda:0' if USE_GPU and torch.cuda.is_available() else 'cpu')
print(device)

# Root folder holding the dataset splits; each split folder contains the images
# and a COCO-format `result.json` annotation file.
DATA_DIR = '/vol/bitbucket/ajm223/SWE_GP/data'

# Convert PIL images to float32 tensors scaled to [0, 1].
transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])


def load_split(split_name):
    """Load one COCO-format split and wrap it so targets carry boxes, labels and masks.

    Args:
        split_name: folder name of the split under ``DATA_DIR``.

    Returns:
        A torchvision dataset yielding ``(image, target)`` pairs.
    """
    split_path = f'{DATA_DIR}/{split_name}'
    dataset = datasets.CocoDetection(split_path, split_path + '/result.json', transforms=transforms)
    return datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])


def collate(batch):
    """Keep images/targets as tuples; detection samples differ in size and cannot be stacked."""
    return tuple(zip(*batch))


train_dataset = load_split('sprint_5_aug_dataset_train')
val_dataset = load_split('sprint_5_aug_dataset_val')

batch_size = 20  # Adjust if memory is an issue

# Shuffle the training data so each batch mixes samples from across the dataset
# instead of walking through it in file order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, collate_fn=collate)

cuda:0
loading annotations into memory...
Done (t=15.26s)
creating index...
index created!
loading annotations into memory...
Done (t=4.02s)
creating index...
index created!


In [2]:
# Use the priority order of the objects to only show the parts of the object that count as that region

import pathlib
import torch
import torch.utils.data
from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2
import numpy as np
from matplotlib import pyplot as plt

def get_anti_mask(image: torch.tensor, masks: torch.tensor, labels: torch.tensor, label_order: torch.tensor) -> torch.tensor:
    """Build one "anti-mask" per class from the instance masks of an image.

    First every instance mask is OR-ed into a per-class union. Then, for each
    class id ``c`` in ``range(len(label_order))``, the anti-mask is the union of
    all classes that appear *after* ``c`` in ``label_order`` (all-zero for the
    last class in the order).

    Args:
        image: image tensor; only ``image.shape[1]`` and ``image.shape[2]`` are
            used, as the mask height and width.
        masks: one mask per instance, indexed in parallel with ``labels``.
        labels: class id of each instance; used directly as an index into the
            per-class stack, so ids must lie in ``range(len(label_order))``.
        label_order: 1-D tensor listing the class ids in priority order.

    Returns:
        A float tensor of shape ``(len(label_order), H, W)`` holding 0/1 values.
    """
    height, width = image.shape[1], image.shape[2]
    num_classes = len(label_order)

    # Union of all instance masks belonging to each class.
    class_union = torch.zeros(size=(num_classes, height, width))
    for idx in range(len(labels)):
        cls = labels[idx]
        class_union[cls] = class_union[cls].bool() | masks[idx].bool()

    order_list = label_order.tolist()
    anti = torch.zeros_like(class_union)
    for cls in range(num_classes):
        # Classes ranked after `cls`; may be empty for the last one in the order.
        later = label_order[order_list.index(cls) + 1:]
        stacked = class_union[later].view(-1, height, width).bool()
        anti[cls] = torch.any(stacked, dim=0)
    return anti

def remove_overlap(image: torch.tensor, masks: torch.tensor, labels: torch.tensor, label_order, display=False) -> torch.tensor:
    """Clip each instance mask so classes later in ``label_order`` take precedence.

    Every pixel of an instance mask that is also covered by a class ranked
    after the instance's own class in ``label_order`` (its anti-mask from
    ``get_anti_mask``) is cleared.

    Args:
        image: image tensor; passed on to ``get_anti_mask`` and, when
            ``display`` is true, its channel 2 is shown.
        masks: one mask per instance.
        labels: class id per instance, indexed in parallel with ``masks``.
        label_order: class ids in priority order (see ``get_anti_mask``).
        display: if true, print each instance's class name and plot
            mask / anti-mask / output / image side by side.

    Returns:
        Tensor with the same shape and dtype as ``masks`` holding the clipped masks.
    """
    class_names = ["Seed Coat", "Interior", "Endosperm", "Void"]
    anti_masks = get_anti_mask(image, masks, labels, label_order)
    masks_out = torch.zeros_like(masks)

    for idx in range(len(masks)):
        label = labels[idx]
        masks_out[idx] = masks[idx].bool() & ~anti_masks[label].bool()
        if not display:
            continue

        print(class_names[label])
        fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, dpi=200)
        panels = [
            (masks[idx], "Mask"),
            (anti_masks[label], "Anti-Mask"),
            (masks_out[idx], "Output"),
            (image[2], "Image"),
        ]
        for axis, (data, title) in zip(axes, panels):
            axis.imshow(data)
            axis.set_title(title)
            axis.axis('off')
        plt.show()

    return masks_out

def process(images, targets, display=False):
    """Prepare a collated batch for the detector.

    Each image is moved to the global ``device``. Each target is rebuilt as a
    dict with 'boxes', 'masks' and 'labels'. Its masks are first passed through
    ``remove_overlap`` with the fixed priority order ``[0, 1, 2, 3]``, and every
    tensor is moved to ``device``.

    NOTE(review): labels are forwarded to the model unchanged. The class-name
    lookups and the 4-slot anti-mask stack suggest category ids are 0..3, but
    torchvision detection models reserve label 0 for background, so class 0
    ('Seed Coat') may never be learned or predicted. Confirm the dataset's
    category ids.

    Args:
        images: sequence of image tensors.
        targets: sequence of target dicts, indexed in parallel with ``images``.
        display: forwarded to ``remove_overlap`` to plot each mask.

    Returns:
        ``(images, targets)`` as two tuples.
    """
    label_order = torch.tensor([0, 1, 2, 3])
    moved_images = []
    moved_targets = []
    for idx, img in enumerate(images):
        tgt = targets[idx]
        moved_images.append(img.to(device))
        cleaned_masks = remove_overlap(img, tgt['masks'], tgt['labels'], label_order, display=display)
        moved_targets.append(dict(
            boxes=tgt['boxes'].to(device),
            masks=cleaned_masks.to(device),
            labels=tgt['labels'].to(device),
        ))
    return tuple(moved_images), tuple(moved_targets)

import matplotlib.patches as patches
import random
def show_progress(image, target, num_masks=1):
    """Plot the model's predictions for one image next to its ground truth.

    Draws a 2 x (1 + num_masks) grid:
      * row 0: predicted boxes (red) in column 0, then ``num_masks`` randomly
        chosen predicted masks overlaid on the image;
      * row 1: target boxes (white) in column 0, then ``num_masks`` randomly
        chosen target masks overlaid on the image.

    Relies on the notebook globals ``model`` and ``device``. The model is run
    once in eval mode under ``torch.no_grad()``. Afterwards it is put back into
    whichever train/eval mode it was in before the call.

    Args:
        image: a single channels-first image tensor (transposed with
            ``(1, 2, 0)`` for display).
        target: dict with 'boxes', 'masks' and 'labels' entries, as built by
            ``process``.
        num_masks: how many random masks to show per row (may be 0).
    """
    class_names = ["Seed Coat", "Interior", "Endosperm", "Void"]

    def class_name(label):
        # The detector has one more output class than there are names, so a
        # predicted label can fall outside this list; show the raw id instead
        # of raising IndexError.
        if 0 <= label < len(class_names):
            return class_names[label]
        return f"Class {label}"

    # squeeze=False keeps `ax` two-dimensional even when num_masks == 0.
    fig, ax = plt.subplots(2, 1 + num_masks, sharex=True, sharey=True, dpi=300,
                           figsize=(3 * (1 + num_masks), 6), squeeze=False)

    cmap_im = None
    cmap_mask = 'jet'

    # Rescale the image to [0, 1] for display; a constant image would otherwise
    # divide 0 by 0 and render as NaNs.
    im = image.detach().cpu().numpy().transpose(1, 2, 0)
    im = im - im.min()
    if im.max() > 0:
        im = im / im.max()

    ax[0, 0].imshow(im, cmap=cmap_im)
    ax[0, 0].set_title('BBoxes')
    ax[1, 0].imshow(im, cmap=cmap_im)
    ax[1, 0].set_title('Original')

    # Predict in eval mode, then restore the caller's mode even if the forward pass fails.
    was_training = model.training
    try:
        model.eval()
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))[0]
    finally:
        model.train(was_training)

    num_pred = len(output['masks'])
    for j in range(num_masks):
        panel = ax[0, j + 1]
        panel.imshow(im, cmap=cmap_im)
        if num_pred == 0:
            # random.randint(0, -1) would raise when the model detects nothing.
            panel.set_title("Pred: none")
            continue
        i = random.randint(0, num_pred - 1)
        maps = output['masks'][i].detach().cpu().squeeze(0).numpy()
        panel.imshow(maps, cmap=cmap_mask, alpha=0.5)
        label = output['labels'][i].item()
        score = output['scores'][i].item()
        panel.set_title(f"Pred: {class_name(label)} - {score*100:3.1f}%")

    num_target = len(target['masks'])
    for j in range(num_masks):
        panel = ax[1, j + 1]
        panel.imshow(im, cmap=cmap_im)
        if num_target == 0:
            panel.set_title("Target: none")
            continue
        i = random.randint(0, num_target - 1)
        maps = target['masks'][i].detach().cpu().squeeze(0).numpy()
        panel.imshow(maps, cmap=cmap_mask, alpha=0.5)
        label = target['labels'][i].item()
        panel.set_title(f"Target: {class_name(label)}")

    for j in range(len(output['boxes'])):
        x1, y1, x2, y2 = output['boxes'][j].detach().cpu().numpy()
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        ax[0, 0].add_patch(rect)

    for j in range(len(target['boxes'])):
        x1, y1, x2, y2 = target['boxes'][j].detach().cpu().numpy()
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='w', facecolor='none')
        ax[1, 0].add_patch(rect)

    for row in ax:
        for panel in row:
            panel.axis('off')
    plt.show()

from torchmetrics.detection.mean_ap import MeanAveragePrecision as mAP
def eval_map(preds, targets):
    """Score a batch of detections with torchmetrics' MeanAveragePrecision.

    Both ``preds`` and ``targets`` are moved to the global ``device`` (nested
    containers included) before a fresh metric instance is evaluated on them.
    The metric is constructed with its defaults, so this is box mAP; mask
    quality is not scored.

    Args:
        preds: list of per-image prediction dicts from the model in eval mode.
        targets: iterable of per-image target dicts.

    Returns:
        The 'map' entry of the metric's result.
    """
    targets_on_device = NestedTensorHandler.get_structure_on_device(list(targets), device)
    preds_on_device = NestedTensorHandler.get_structure_on_device(preds, device)
    metric = mAP().to(device)
    return metric(preds_on_device, targets_on_device)['map']

class NestedTensorHandler:
    """Helpers for moving arbitrarily nested containers of tensors to a device."""

    @staticmethod
    def to_device(item, device):
        """Return ``item`` with every tensor it contains moved to ``device``.

        Tensors are moved with ``.to(device)``. dicts, lists and tuples are
        rebuilt as plain ``dict``/``list``/``tuple`` with their contents
        converted recursively. Any other value is returned unchanged.
        """
        def move(value):
            return NestedTensorHandler.to_device(value, device)

        if isinstance(item, torch.Tensor):
            return item.to(device)
        if isinstance(item, dict):
            return {key: move(value) for key, value in item.items()}
        if isinstance(item, list):
            return list(map(move, item))
        if isinstance(item, tuple):
            return tuple(map(move, item))
        return item

    @staticmethod
    def get_structure_on_device(nested_structure, device='cpu'):
        """Same as ``to_device`` but with the device defaulting to the CPU."""
        return NestedTensorHandler.to_device(nested_structure, device)

In [3]:
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_model(num_classes, model_path=None):
    """Build a Mask R-CNN (ResNet-50 FPN) with its heads resized to ``num_classes``.

    Starts from torchvision's default pre-trained weights. The box predictor
    and mask predictor are replaced with freshly initialised ones sized for
    ``num_classes``. Callers in this notebook pass the number of object classes
    + 1 (for the background class). Optionally loads a saved state dict.

    Args:
        num_classes: number of output classes for both the box and mask heads.
        model_path: optional path to a state dict saved with
            ``torch.save(model.state_dict(), ...)``.

    Returns:
        The model, still on the CPU; callers move it with ``.to(device)``.
    """
    model = maskrcnn_resnet50_fpn(weights='DEFAULT')

    # Replace the box classifier/regressor head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head, keeping the usual 256-channel hidden layer.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    if model_path is not None:
        # map_location='cpu': the model is on the CPU here, and this lets a
        # checkpoint written from cuda:0 load on a machine without a GPU.
        # weights_only=True: a state dict needs no arbitrary pickled objects.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict)
    return model

cuda:0


In [4]:
load_from_checkpoint = True

model_path = 'model.pt'
# Four object classes plus the background class torchvision detection models expect.
num_classes = len(['Seed Coat', 'Interior', 'Endosperm', 'Void']) + 1
model = get_model(num_classes).to(device)
if load_from_checkpoint:
    # map_location lets a checkpoint saved on cuda:0 be restored when `device`
    # fell back to the CPU; weights_only avoids unpickling arbitrary objects.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))

In [9]:
# Take the first validation batch: keep the raw tuples in images_val/targets_val
# and the processed copies (overlap-free masks, tensors on `device`) in
# image_val/target_val.
val_batch = next(iter(val_loader))
images_val, targets_val = val_batch
image_val, target_val = process(*val_batch)

In [None]:
from tqdm.notebook import tqdm
import gc

optimizer = torch.optim.Adamax(model.parameters(), lr=1e-4)

num_epochs = 1  # Set the number of epochs
batch_size = 20  # Adjust if memory is an issue

# Shuffle the training set so consecutive batches are not drawn from one
# contiguous slice of the (augmented) data.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda batch: tuple(zip(*batch)),)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=10, collate_fn=lambda batch: tuple(zip(*batch)),)

# Log about five times per epoch; max(1, ...) avoids `i % 0` when the loader has < 5 batches.
print_every = max(1, len(train_loader) // 5)

model.to(device)
losses = []

for epoch in tqdm(range(num_epochs), total=num_epochs):
    for i, (images, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
        model.train()
        images, targets = process(images, targets)

        # In train mode the model returns a dict of loss terms; optimise their sum.
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        gc.collect()

        if i % print_every == print_every - 1 or i == 0:
            # Score the first validation batch and the current training batch.
            images_val, targets_val = next(iter(val_loader))
            images_val, targets_val = process(images_val, targets_val)
            model.eval()
            # Evaluation needs no autograd graph; without no_grad these forward
            # passes would keep activations alive and waste GPU memory.
            with torch.no_grad():
                preds_val = model(images_val)
                val_mAP = eval_map(preds_val, targets_val)

                preds_train = model(images)
                train_mAP = eval_map(preds_train, targets)
            model.train()
            tqdm.write(f"Epoch {epoch} | Iteration {i} | Loss {loss.item():1.3f} | Train mAP: {train_mAP.item():1.3f} | Val mAP: {val_mAP.item():1.3f}")

load_from_checkpoint = True
torch.save(model.state_dict(), model_path)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/9800 [00:00<?, ?it/s]

Epoch 0 | Iteration 0 | Loss 0.3520944118499756 | Train mAP: 0.5038595199584961 | Val mAP: 0.4518515169620514


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f0c4cfe1cc0>>
Traceback (most recent call last):
  File "/vol/bitbucket/or623/virtual_envs/nlp_env/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [None]:
# Fetch the first validation batch and keep only its processed form
# (overlap-free masks, tensors on `device`).
images_val, targets_val = process(*next(iter(val_loader)))

In [None]:
# Forward pass on the processed validation batch without recording gradients.
# NOTE(review): elsewhere in this notebook the model returns a loss dict in
# train mode when given targets, and a list of per-image predictions after
# model.eval(). This cell does not set the mode, so what `output` holds depends
# on the mode the model was left in. Confirm which one is intended.
with torch.no_grad():
    output = model(images_val, targets_val)

In [None]:
def show_progress(image, target, num_masks=1):
    """Plot the model's predictions for one image next to its ground truth.

    NOTE(review): this cell redefines ``show_progress`` (an earlier cell
    defines it too). When run top to bottom, this later definition wins.
    Consider keeping only one copy.

    Draws a 2 x (1 + num_masks) grid:
      * row 0: predicted boxes (red) in column 0, then ``num_masks`` randomly
        chosen predicted masks overlaid on the image;
      * row 1: target boxes (white) in column 0, then ``num_masks`` randomly
        chosen target masks overlaid on the image.

    Relies on the notebook globals ``model``, ``device``, ``random`` and
    ``patches``. The model is run once in eval mode under ``torch.no_grad()``.
    Afterwards it is put back into whichever train/eval mode it was in before
    the call.

    Args:
        image: a single channels-first image tensor (transposed with
            ``(1, 2, 0)`` for display).
        target: dict with 'boxes', 'masks' and 'labels' entries, as built by
            ``process``.
        num_masks: how many random masks to show per row (may be 0).
    """
    class_names = ["Seed Coat", "Interior", "Endosperm", "Void"]

    def class_name(label):
        # The detector has one more output class than there are names, so a
        # predicted label can fall outside this list; show the raw id instead
        # of raising IndexError.
        if 0 <= label < len(class_names):
            return class_names[label]
        return f"Class {label}"

    # squeeze=False keeps `ax` two-dimensional even when num_masks == 0.
    fig, ax = plt.subplots(2, 1 + num_masks, sharex=True, sharey=True, dpi=300,
                           figsize=(3 * (1 + num_masks), 6), squeeze=False)

    cmap_im = None
    cmap_mask = 'jet'

    # Rescale the image to [0, 1] for display; a constant image would otherwise
    # divide 0 by 0 and render as NaNs.
    im = image.detach().cpu().numpy().transpose(1, 2, 0)
    im = im - im.min()
    if im.max() > 0:
        im = im / im.max()

    ax[0, 0].imshow(im, cmap=cmap_im)
    ax[0, 0].set_title('BBoxes')
    ax[1, 0].imshow(im, cmap=cmap_im)
    ax[1, 0].set_title('Original')

    # Predict in eval mode, then restore the caller's mode even if the forward pass fails.
    was_training = model.training
    try:
        model.eval()
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))[0]
    finally:
        model.train(was_training)

    num_pred = len(output['masks'])
    for j in range(num_masks):
        panel = ax[0, j + 1]
        panel.imshow(im, cmap=cmap_im)
        if num_pred == 0:
            # random.randint(0, -1) would raise when the model detects nothing.
            panel.set_title("Pred: none")
            continue
        i = random.randint(0, num_pred - 1)
        maps = output['masks'][i].detach().cpu().squeeze(0).numpy()
        panel.imshow(maps, cmap=cmap_mask, alpha=0.5)
        label = output['labels'][i].item()
        score = output['scores'][i].item()
        panel.set_title(f"Pred: {class_name(label)} - {score*100:3.1f}%")

    num_target = len(target['masks'])
    for j in range(num_masks):
        panel = ax[1, j + 1]
        panel.imshow(im, cmap=cmap_im)
        if num_target == 0:
            panel.set_title("Target: none")
            continue
        i = random.randint(0, num_target - 1)
        maps = target['masks'][i].detach().cpu().squeeze(0).numpy()
        panel.imshow(maps, cmap=cmap_mask, alpha=0.5)
        label = target['labels'][i].item()
        panel.set_title(f"Target: {class_name(label)}")

    for j in range(len(output['boxes'])):
        x1, y1, x2, y2 = output['boxes'][j].detach().cpu().numpy()
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        ax[0, 0].add_patch(rect)

    for j in range(len(target['boxes'])):
        x1, y1, x2, y2 = target['boxes'][j].detach().cpu().numpy()
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='w', facecolor='none')
        ax[1, 0].add_patch(rect)

    for row in ax:
        for panel in row:
            panel.axis('off')
    plt.show()

In [None]:
# Pull the first validation batch, process it, and plot 4 random predicted vs.
# ground-truth masks for its first image.
val_batch = next(iter(val_loader))
images_val, targets_val = val_batch
image_val, target_val = process(*val_batch)
show_progress(image_val[0], target_val[0], num_masks=4)