In [None]:
import torch
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms as T
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

#LOAD MODEL
# The RPN proposal budget must be passed to the constructor: RegionProposalNetwork
# stores it in a private dict that is read through a method, so assigning
# `model.rpn.post_nms_top_n_train` / `post_nms_top_n_test` after construction only
# creates unused attributes and leaves the library defaults in effect.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    pretrained=False,
    rpn_post_nms_top_n_train=2000,
    rpn_post_nms_top_n_test=2000,
)

# Five anchor sizes (one tuple per feature level), each with three aspect ratios.
anchor_generator = AnchorGenerator(sizes=((16,), (32,), (64,), (128,), (256,)), aspect_ratios=((0.5, 1.0, 2.0),) * 5)
model.rpn.anchor_generator = anchor_generator

# Replace the box and mask heads so that they predict 5 classes.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 5)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, 5)

# Upper bound on the number of detections returned per image.
model.roi_heads.detections_per_img = 2000


In [None]:
#LOAD EXISTING SAVED MODEL
model_path = 'MODELPATH'  #update path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps the stored tensors onto the selected device, so a checkpoint
# written during a GPU run can still be loaded on a CPU-only machine.
saved_state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(saved_state_dict)
model.to(device)

In [None]:
#LOAD DATASET

class CustDat(torch.utils.data.Dataset):
    """Instance-segmentation dataset pairing each image with its per-object mask files.

    A mask file belongs to an image when its name starts with ``"<image stem>_"``
    and ends with ``.png``.  The character right before ``.png`` selects the
    class id: ``g`` -> 2, ``o`` -> 3, ``l`` -> 4, anything else -> 1.
    """

    # Class id keyed by the last character of the mask file stem.
    _SUFFIX_TO_LABEL = {'g': 2, 'o': 3, 'l': 4}
    # Class id used when the suffix character is not in the table above.
    _DEFAULT_LABEL = 1

    def __init__(self, image_names, images_directory, masks_directory):
        """
        Initializes the dataset.

        :param image_names: A list of image file names (e.g., ['0.png', '1.png', ...])
        :param images_directory: The directory where image files are located
        :param masks_directory: The directory where mask files are located
        """
        # Keep only names that are regular files in the image directory and are not hidden.
        self.image_names = [name for name in image_names if os.path.isfile(os.path.join(images_directory, name)) and not name.startswith('.')]
        self.images_directory = images_directory
        self.masks_directory = masks_directory

    @classmethod
    def _label_for(cls, mask_file):
        """Return the class id encoded by the character preceding ``.png`` in ``mask_file``."""
        return cls._SUFFIX_TO_LABEL.get(mask_file[-5], cls._DEFAULT_LABEL)

    @staticmethod
    def _build_target(masks, labels, height, width):
        """Assemble the target dict from boolean instance masks and their labels.

        Masks without any foreground pixel are dropped together with their label,
        because no bounding box can be derived from them.  When no instance is
        left, correctly shaped empty tensors are returned instead of raising.

        :param masks: list of boolean arrays, one per instance
        :param labels: list of integer class ids, parallel to ``masks``
        :param height: image height, used only for the shape of an empty mask tensor
        :param width: image width, used only for the shape of an empty mask tensor
        :return: dict with ``boxes`` (N, 4) float32 as [xmin, ymin, xmax, ymax],
                 ``labels`` (N,) int64 and ``masks`` uint8 stacked along dim 0
        """
        # NOTE(review): boxes use inclusive max coordinates, so a mask that is a single
        # pixel wide or tall yields a zero-area box -- confirm the model tolerates that.
        kept_masks = []
        kept_labels = []
        boxes = []
        for mask, label in zip(masks, labels):
            pos = np.where(mask)
            if pos[0].size == 0:
                # Nothing is set in this mask: there is no object to describe.
                continue
            boxes.append([int(np.min(pos[1])), int(np.min(pos[0])),
                          int(np.max(pos[1])), int(np.max(pos[0]))])
            kept_masks.append(mask)
            kept_labels.append(label)

        if kept_masks:
            mask_tensor = torch.stack([torch.tensor(mask, dtype=torch.uint8) for mask in kept_masks])
        else:
            # assumes masks share the image's height and width -- TODO confirm
            mask_tensor = torch.zeros((0, height, width), dtype=torch.uint8)

        return {
            # reshape keeps the (N, 4) layout even when N == 0
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(kept_labels, dtype=torch.int64),
            "masks": mask_tensor,
        }

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image_tensor, target_dict)``."""
        img_name = self.image_names[idx]
        img_path = os.path.join(self.images_directory, img_name)
        img = Image.open(img_path).convert("RGB")

        # Extract the base index of the image to match with masks
        base_idx = img_name.split('.')[0]
        prefix = f"{base_idx}_"

        # Loading separate mask images according to the naming convention.
        # Sorted so the instance order does not depend on the directory listing order.
        mask_files = sorted(
            file for file in os.listdir(self.masks_directory)
            if file.startswith(prefix) and file.endswith('.png')
        )

        masks = []
        labels = []
        for mask_file in mask_files:
            mask_path = os.path.join(self.masks_directory, mask_file)
            mask = np.array(Image.open(mask_path))
            # Any non-zero pixel counts as foreground.
            masks.append(mask > 0)
            labels.append(self._label_for(mask_file))

        width, height = img.size
        target = self._build_target(masks, labels, height, width)

        return T.ToTensor()(img), target

    def __len__(self):
        """Number of images kept by the constructor's filter."""
        return len(self.image_names)


# Keep only regular, non-hidden files -- the same filter CustDat applies -- so that
# entries such as `.ipynb_checkpoints` or `.DS_Store` do not distort the split sizes.
images = sorted(
    name for name in os.listdir("genimages/")
    if not name.startswith('.') and os.path.isfile(os.path.join("genimages/", name))
)
transform = T.ToTensor()
def custom_collate(data):
    """Identity collate function: hand the list of samples back unchanged.

    :param data: the list of samples gathered for one batch
    :return: the very same object, without stacking or copying
    """
    return data

# Reproducible 80/20 split. A fixed seed keeps the validation images identical
# across kernel restarts, so resuming from a saved checkpoint does not end up
# training on images that were held out in an earlier run.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)

num = int(0.8 * len(images))
num = num if num % 2 == 0 else num + 1  # round the training count up to an even number
train_imgs_inds = split_rng.choice(len(images), num, replace=False)
# Everything that was not drawn for training becomes validation data.
val_imgs_inds = np.setdiff1d(range(len(images)), train_imgs_inds)
train_imgs = np.array(images)[train_imgs_inds]
val_imgs = np.array(images)[val_imgs_inds]


# Options shared by the training and validation loaders.
_loader_options = dict(
    batch_size=1,
    shuffle=True,
    collate_fn=custom_collate,
    num_workers=8,
    pin_memory=True,
)

train_dataset = CustDat(train_imgs, 'genimages/', 'masks/')
val_dataset = CustDat(val_imgs, 'genimages/', 'masks/')

train_dl = torch.utils.data.DataLoader(train_dataset, **_loader_options)
val_dl = torch.utils.data.DataLoader(val_dataset, **_loader_options)


In [None]:
#TRAINING
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.0005, momentum=0.8, weight_decay=0.0005)
all_train_losses = []
all_val_losses = []

# Define the learning rate scheduler: multiply the learning rate by 0.5 every 20 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

NUM_EPOCHS = 50


def _batch_loss(batch):
    """Run one batch through the model and return the summed loss tensor.

    :param batch: list of ``(image_tensor, target_dict)`` samples as yielded by the loaders
    :return: the sum of every entry of the loss dict returned by the model
    """
    imgs = [img.to(device) for img, _ in batch]
    targets = [{k: v.to(device) for k, v in tgt.items()} for _, tgt in batch]
    loss_dict = model(imgs, targets)
    return sum(loss_dict.values())


#TRAINING
for epoch in range(NUM_EPOCHS):
    # Accumulate as Python floats (.item()) rather than 0-d float32 arrays.
    train_epoch_loss = 0.0
    val_epoch_loss = 0.0
    # The model is left in train mode for the validation pass below as well:
    # torchvision detection models return the loss dict only in training mode.
    model.train()

    #TRAIN LOSS
    for batch in train_dl:
        losses = _batch_loss(batch)
        train_epoch_loss += losses.item()
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
    all_train_losses.append(train_epoch_loss)

    #TEST LOSS
    # no_grad: the validation loss is only measured, never back-propagated.
    with torch.no_grad():
        for batch in val_dl:
            val_epoch_loss += _batch_loss(batch).item()
        all_val_losses.append(val_epoch_loss)

    # Update the learning rate
    scheduler.step()

    # Both values are sums over batches, not per-batch means.
    print(epoch , "  " , train_epoch_loss , "  " , val_epoch_loss)

torch.save(model.state_dict(), 'MODELPATH')  #update path
