In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import albumentations as A
import cv2
import numpy as np
import skimage as ski

import matplotlib.pyplot as plt
import os
import copy

from tqdm import tqdm
from IPython.display import clear_output

import psutil
import pynvml

import sys
sys.path.append('/home/meribejayson/Desktop/Projects/SharkCNN/training_models/dataloaders/')

from sharkdataset import SharkDatasetTrain as SharkDataset

In [2]:
# Seed PyTorch's RNG so repeated runs start from the same random state.
torch.manual_seed(12)

# Every tensor and the model are moved to the GPU later on, so fail fast
# when CUDA is not available.
if not torch.cuda.is_available():
    raise Exception("Couldn't find CUDA")
device = torch.device("cuda")

# Open an NVML session and grab a handle to GPU index 0; the training loop
# uses it to report GPU memory usage.
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# Dataset locations.
# NOTE(review): absolute, machine-specific paths — consider a configurable root.
megaset_path = "/home/meribejayson/Desktop/Projects/SharkCNN/datasets-reduced/megaset/"
megaset_train_images_path = "/home/meribejayson/Desktop/Projects/SharkCNN/datasets-reduced/megaset/train/images/"
megaset_train_labels_path = "/home/meribejayson/Desktop/Projects/SharkCNN/datasets-reduced/megaset/train/labels/"

# Frame dimensions (presumably in pixels — confirm against the dataset).
image_width, image_height = 1920, 1080

In [3]:
class LogisticRegresion(nn.Module):
    """Binary logistic regression: a single linear layer followed by a sigmoid."""

    def __init__(self, input_size):
        """Build the model for feature vectors of length ``input_size``."""
        super().__init__()

        # Attribute names are kept as-is: they determine the state_dict keys
        # ("linear.weight", "linear.bias") that saved checkpoints rely on.
        self.linear = nn.Linear(in_features=input_size, out_features=1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(linear(x))``; the last dimension of the result has size 1."""
        logits = self.linear(x)
        probabilities = self.sig(logits)
        return probabilities

In [4]:
# Training dataset from the project's `sharkdataset` module (its behaviour is
# not visible in this notebook).
shark_dataset = SharkDataset()

# Batches of up to 1,000,000 items, prepared by 5 worker processes.
data_loader = data.DataLoader(
    dataset=shark_dataset,
    batch_size=1_000_000,
    num_workers=5,
)

In [5]:
# Restore the weights produced by the previous training run.
# `map_location` places the loaded tensors on the training device regardless
# of which device they were saved from, so the load does not depend on the
# GPU layout of the machine that wrote the checkpoint.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source (newer PyTorch versions offer `weights_only=True`).
state_dict = torch.load("./train-final-2/lr_weights_train_2.tar", map_location=device)

In [6]:
# 85 input features; this has to match the shape of the stored `linear.weight`
# or `load_state_dict` will reject the checkpoint.
model = LogisticRegresion(85)
model.load_state_dict(state_dict)
model = model.to(device)

# Adam with L2 weight decay over all model parameters.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

In [7]:
# Sentinel used as the "previous loss" before the first pass, so the
# convergence test in train_model cannot trigger immediately.
LARGE_NUM = 2e120

# Training stops once the per-pass loss changes by no more than this.
target_loss_change = 1e-5

# NOTE(review): not referenced by any code visible in this notebook.
exps_in_iter = 2 * image_width * image_height

# Class weighting: label-1 examples are weighted by kappa_inv, label-0 examples
# by 1, and all weights are then scaled by coef.
# NOTE(review): 323 is presumably the negative:positive ratio — confirm.
kappa_inv = 323
kappa = 1 / kappa_inv
coef = (1 + kappa) / 2

# Offset applied to hard 0/1 labels before the loss is computed.
small_peturb = 2e-120

def train_model(model, optimizer, data_loader):
    """Train ``model`` with class-weighted binary cross-entropy until the loss plateaus.

    Each pass of the outer loop iterates once over ``data_loader``. Every batch
    is indexed as a 2-D tensor whose last column holds the label and whose
    remaining columns hold the features. Training stops when the mean
    per-batch loss of two consecutive passes differs by at most
    ``target_loss_change``.

    Relies on notebook-level globals: ``device``, ``handle``, ``LARGE_NUM``,
    ``target_loss_change``, ``kappa_inv``, ``coef`` and ``small_peturb``.

    Args:
        model: Module mapping a ``(batch, features)`` float tensor to
            probabilities of shape ``(batch, 1)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        data_loader: Iterable yielding the batches described above.

    Returns:
        None. ``model`` is updated in place and progress is printed after
        every pass.
    """
    model.train()
    last_average_loss = LARGE_NUM
    curr_average_loss = 0
    curr_iter = 1

    while abs(curr_average_loss - last_average_loss) > target_loss_change:
        loss_sum = 0.0
        batch_count = 0

        for point in data_loader:
            data_inputs = point[:, :-1].to(device).float()
            data_labels = point[:, -1].to(device).float()

            # Model output is (batch, 1); drop the trailing axis to match the labels.
            preds = model(data_inputs).squeeze(dim=1)

            # Per-example weights: 1 for label 0, kappa_inv for label 1, then
            # scaled by coef. Must be built before the labels are perturbed.
            weights = torch.clone(data_labels)
            weights[data_labels == 0.0] = 1
            weights[data_labels == 1.0] = kappa_inv
            weights = coef * weights

            # Nudge hard labels away from exactly 0 and 1.
            # NOTE(review): the labels are float32 here, so a perturbation
            # smaller than float32 can represent (e.g. 2e-120) is stored as
            # 0.0 / 1.0 and this becomes a no-op — confirm the intended size.
            data_labels[data_labels == 0.0] = small_peturb
            data_labels[data_labels == 1.0] = 1.0 - small_peturb

            loss_module = nn.BCELoss(weight=weights)
            loss = loss_module(preds, data_labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            batch_count += 1

        last_average_loss = curr_average_loss
        # Report a true mean over the batches of this pass (previously the raw
        # sum was stored under the "average" name). max() keeps an empty
        # loader from dividing by zero.
        curr_average_loss = loss_sum / max(batch_count, 1)

        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        clear_output(wait=True)
        # There is no previous pass to report on the first iteration.
        if curr_iter > 1:
            print(f'Current iteration: {curr_iter - 1}, Average Loss: {last_average_loss}')
        print(f'Current iteration: {curr_iter}, Average Loss: {curr_average_loss}')
        # used/total is a fraction in [0, 1]; scale to a percentage before truncating.
        print(f"CPU Usage: {psutil.cpu_percent()}% GPU memory usage: {int(100 * info.used / info.total)}% \n")

        print("Current Parameters:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(name, param.data)

        curr_iter += 1
        



In [8]:
# Run training until the loss plateaus (or the cell is interrupted manually).
train_model(model=model, optimizer=optimizer, data_loader=data_loader)

KeyboardInterrupt: 

In [None]:
# Training is over: close the NVML session opened during setup.
pynvml.nvmlShutdown()

# Snapshot the trained parameters, echo them, and write them to disk.
state_dict = model.state_dict()
print(state_dict)
torch.save(obj=state_dict, f="lr_weights_train_3.tar")

OrderedDict([('linear.weight', tensor([[ 8.8438e-02, -4.1447e-02,  5.7988e-02,  2.5508e-01,  2.1914e-01,
          4.1002e-02,  1.2102e-01,  2.4286e-02,  2.7373e-01,  5.5927e-02,
          1.5209e-02, -7.9003e-02, -1.0960e-02, -1.7376e-01, -2.4090e-01,
         -1.8512e-01, -2.1791e-01, -2.7587e-01, -1.2228e-01, -1.0727e-01,
         -1.2169e-01, -1.4348e-01, -1.3062e-01, -1.9141e-01, -2.1120e-01,
         -2.9247e-01, -1.2279e-01, -1.0668e-01, -1.2639e-01, -1.8587e-01,
         -2.6121e-01, -1.7965e-01, -2.2002e-01, -2.6013e-01, -1.2802e-01,
         -1.1201e-01, -1.4598e-01,  1.7382e-03,  3.8620e-03, -4.4278e-03,
          1.7996e-02, -1.1002e-02,  2.4375e-03,  5.7739e-03, -1.1377e-03,
          8.4366e-03,  8.9391e-03,  3.1053e-02,  2.5116e-02,  2.1628e-02,
          8.6249e-03,  2.6799e-02,  4.3937e-03, -3.9218e-03, -9.3392e-04,
         -3.5854e-02,  1.2282e-02, -4.6147e-02, -1.9868e-03, -1.0342e-02,
         -6.3967e-03,  4.6510e-03,  1.6078e-03, -1.2351e-03, -5.8392e-03,
       