# Salmon Dataset Class

In [58]:
import torch
import os, json
import numpy as np
from torchvision.io import read_image
from torchvision import tv_tensors

class SalmonDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        
        self.imgs = [file for file in sorted(os.listdir(os.path.join(root, "Images"))) if file.endswith(('.jpg', '.jpeg', '.png'))]
        self.annots = [file for file in sorted(os.listdir(os.path.join(root, "Boxes"))) if file.endswith('.json')]
    
    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "Images", self.imgs[idx])
        annots_path = os.path.join(self.root, "Boxes", self.annots[idx])
        
        img = read_image(img_path)
        if "DS" not in annots_path:
            try:
                shapes = json.load(open(annots_path))['shapes']
            except:
                print("Cannot open json file...")
            
        boxes = [shape['points'] for shape in shapes]
        boxes_flattened = []
        for box in boxes:
            box = np.array(box).flatten().tolist()
            boxes_flattened.append(box)
        
        boxes_flattened = torch.tensor(boxes_flattened, dtype=torch.float)
        

        # Number of boxes in image
        num_objs = len(boxes_flattened)

        # there is only one class -> Salmon = 1
        labels = torch.ones((num_objs,), dtype=torch.int64)

        # Wrap sample and targets into torchvision tv_tensors:
        img = tv_tensors.Image(img)
        img = img.float() / 255.0
        

        target = {}
        target["boxes"] = boxes_flattened
        target["labels"] = labels
        target["image_id"] = idx
        target["area"] = (boxes_flattened[:,2]-boxes_flattened[:,0])*(boxes_flattened[:,3]-boxes_flattened[:,1])
        target["iscrowd"] = torch.zeros((num_objs,), dtype=torch.int64)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
            
        if target["boxes"].numel() == 0:
            print("No boxes :(")

        return img, target

    def __len__(self):
        return len(self.imgs)

# Get Faster RCNN model pretrained on COCO

In [59]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights

def get_detection_model(num_classes, weights=FasterRCNN_ResNet50_FPN_Weights):
    # load a model pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    
    return model

# Create Training, Validation and Test sets

In [60]:
import torch
import random
import numpy as np

# Random seeds for reproducibility
g = torch.manual_seed(0)
random.seed(0)

data_path = "/Users/magnuswiik/prosjektoppgave_data/Masteroppgave_data/Helfisk_Deteksjonssett/"

# use our dataset and defined transformations
dataset = SalmonDataset(data_path)
dataset_validation = SalmonDataset(data_path)
dataset_test = SalmonDataset(data_path)

data_indices = np.arange(0,len(dataset.imgs)*0.2, dtype=np.int16).tolist() # TODO: Fjern *0.2

indices_test = random.sample(data_indices, int(len(data_indices)*0.2))
data_indices = [idx for idx in data_indices if idx not in indices_test]

indices_validation = random.sample(data_indices, int(len(data_indices)*0.2))
data_indices = [idx for idx in data_indices if idx not in indices_validation]

indices_training = random.sample(data_indices, int(len(data_indices)))

# split the dataset in train and test set
dataset_training = torch.utils.data.Subset(dataset, indices_training) # 80% for training and validation
dataset_validation = torch.utils.data.Subset(dataset_validation, indices_validation)
dataset_test = torch.utils.data.Subset(dataset_test, indices_test) # 20% for testing

# Helper functions for training 1

In [61]:
import numpy as np
import torch
import matplotlib.pyplot as plt
plt.style.use('ggplot')


# this class keeps track of the training and validation loss values...
# ... and helps to get the average for each epoch as well
class Averager:
    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0.0
        
    def send(self, value):
        self.current_total += value
        self.iterations += 1
    
    @property
    def value(self):
        if self.iterations == 0:
            return 0
        else:
            return 1.0 * self.current_total / self.iterations
    
    def reset(self):
        self.current_total = 0.0
        self.iterations = 0.0
        
        
class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's 
    validation loss is less than the previous least less, then save the
    model state.
    """
    def __init__(
        self, best_valid_loss=float('inf')
    ):
        self.best_valid_loss = best_valid_loss
        
    def __call__(
        self, current_valid_loss, 
        epoch, model, optimizer
    ):
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                }, 'outputs/best_model.pth')

# Helper functions for training 2

In [62]:
def collate_fn(batch):
    """
    To handle the data loading as different images may have different number 
    of objects and to handle varying size tensors as well.
    """
    return tuple(zip(*batch))

def save_model(epoch, model, optimizer):
    """
    Function to save the trained model till current epoch, or whenver called
    """
    torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                }, 'outputs/last_model.pt')
    
    
def save_loss_plot(OUT_DIR, train_loss, val_loss):
    figure_1, train_ax = plt.subplots()
    figure_2, valid_ax = plt.subplots()
    train_ax.plot(train_loss, color='tab:blue')
    train_ax.set_xlabel('iterations')
    train_ax.set_ylabel('train loss')
    valid_ax.plot(val_loss, color='tab:red')
    valid_ax.set_xlabel('iterations')
    valid_ax.set_ylabel('validation loss')
    figure_1.savefig(f"{OUT_DIR}/train_loss.png")
    figure_2.savefig(f"{OUT_DIR}/valid_loss.png")
    print('SAVING PLOTS COMPLETE...')
    plt.close('all')

# Training Detection Model

In [63]:
# Train and evaluate the model
import utils
import os, json
from tqdm.auto import tqdm
import torch
import matplotlib.pyplot as plt
import pandas as pd
plt.style.use('ggplot')

# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 2


# define training and validation data loaders
data_loader_training = torch.utils.data.DataLoader(
    dataset_training,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=utils.collate_fn,
    generator=g
)

# define training and validation data loaders
data_loader_validation = torch.utils.data.DataLoader(
    dataset_validation,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=utils.collate_fn,
    generator=g
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=utils.collate_fn,
    generator=g
)

# get the model using our helper function
model = get_detection_model(num_classes)

# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

### Training

num_epochs = 1

train_loss_hist = Averager()
train_loss_list = []
lr_step_sizes = []
validation_losses = []


for epoch in range(num_epochs):
    
    # initialize tqdm progress bar
    prog_bar = tqdm(data_loader_training, total=len(data_loader_training))
    
    train_loss_per_epoch = []

    for i, data in enumerate(prog_bar):
        optimizer.zero_grad()
        images, targets = data
        
        images = list(image.to(device) for image in images)
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        train_loss_per_epoch.append(loss_value)
        losses.backward()
        optimizer.step()
        lr_scheduler.step()

        # update the loss value beside the progress bar for each iteration
        prog_bar.set_description(desc=f"|Epoch: {epoch+1}/{num_epochs}| Loss: {loss_value:.4f}")
    
    validation_loss = 0.0
    with torch.no_grad():
        for images, targets in data_loader_validation:
            images = [image.to(device) for image in images]
            loss_dict = model(images, targets)
            loss = sum(loss for loss in loss_dict.values())
            validation_loss += loss.item()
    
    
    # Save metrics per epoch
    lr_step_sizes.append(lr_scheduler.get_last_lr()[0])
    train_loss_list.append(sum(loss for loss in train_loss_per_epoch)/len(train_loss_per_epoch))
    train_loss_hist.send(sum(loss for loss in train_loss_per_epoch)/len(train_loss_per_epoch))
    validation_loss /= len(data_loader_validation)
    validation_losses.append(validation_loss)


### SAVING RESULTS

MODELPATH = "models/" + "model1/"

if not os.path.exists(MODELPATH):
    os.mkdir(MODELPATH)
    
dict = {'training_loss': train_loss_list, 'lr_step_size': lr_step_sizes, 'validation_losses': validation_losses}
df = pd.DataFrame(dict)
df.to_csv(MODELPATH + 'metrics.csv', index=False)

torch.save(model.state_dict(), MODELPATH + "model1.pt")
print("Model is saved at:" + MODELPATH + "model1.pt")


|Epoch: 1/1| Loss: 0.3208: 100%|██████████| 22/22 [01:27<00:00,  3.97s/it]


Model is saved at:models/model1/model1.pt
