In [None]:
# Mount Google Drive inside the Colab VM so that the dataset stored under
# /content/drive/MyDrive is reachable through ordinary file paths.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


# Import Libraries

In [None]:
import os
import cv2
import copy
import random
import argparse
import numpy as np

import torch
import torch.optim as optim

from network import MyNet
from torchsummary import summary
from dataset import CropData, LoadImages
from torch.utils.data import DataLoader, Dataset

from torchvision.models import resnet50, resnet101, resnet18

# Transfer learning (TL): ResNet-18 without pretrained weights

In [None]:
# Randomly initialised ResNet-18 with its last two child modules removed, so
# the network ends in a spatial feature map instead of a classification vector.
resnet_18 = torch.nn.Sequential(
    *list(resnet18(pretrained=False).children())[:-2]
)

In [None]:
# Directory on the mounted Drive that holds the crop images.
# A single leading slash is used, matching the other Drive paths in this notebook.
IMG_PATH = "/content/drive/MyDrive/Data-20220522T105527Z-002/Rye"

In [None]:
# Read every image under IMG_PATH into one array (resized to 224x224),
# then wrap that array in the project's dataset class.
crop_images_array = LoadImages(
    IMG_PATH,
    dsize=(224, 224),
).load_images_into_array()

crop_dataset = CropData(images=crop_images_array)


Images loaded into array:  (315, 224, 224, 3)


In [None]:
# Shuffled mini-batches of 32 images for the first training run.
crop_dataloader = DataLoader(
    crop_dataset,
    shuffle=True,
    batch_size=32,
)

In [None]:
# Instantiate the project's custom network (defined in network.py).
# NOTE(review): the positional arguments are presumably
# (input_dim, nChannel, nConv) -- confirm against MyNet's signature.
# In the cells visible here `net` is only passed to `summary`; training uses
# the ResNet-18 backbone instead.
net = MyNet(3, 100, 3)

In [None]:
# Print a layer-by-layer summary of the truncated ResNet-18 for a 3x224x224 input.
summary(resnet_18, input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [None]:
# Print a layer-by-layer summary of the custom MyNet for a 3x224x224 input.
summary(net, input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1        [-1, 100, 224, 224]           2,800
       BatchNorm2d-2        [-1, 100, 224, 224]             200
            Conv2d-3        [-1, 100, 224, 224]          90,100
       BatchNorm2d-4        [-1, 100, 224, 224]             200
            Conv2d-5        [-1, 100, 224, 224]          90,100
       BatchNorm2d-6        [-1, 100, 224, 224]             200
            Conv2d-7        [-1, 100, 224, 224]          10,100
       BatchNorm2d-8        [-1, 100, 224, 224]             200
Total params: 193,900
Trainable params: 193,900
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 306.25
Params size (MB): 0.74
Estimated Total Size (MB): 307.56
----------------------------------------------------------------


In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The network being trained is the truncated ResNet-18 backbone; it is put
# into training mode and moved to `device` (both calls return the same module).
model = resnet_18.train().to(device)

# similarity loss definition
loss_fn = torch.nn.CrossEntropyLoss()

# scribble loss definition
loss_fn_scr = torch.nn.CrossEntropyLoss()

# continuity loss definition
# `reduction="mean"` is the non-deprecated spelling of `size_average=True`.
loss_hpy = torch.nn.L1Loss(reduction="mean")
loss_hpz = torch.nn.L1Loss(reduction="mean")

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Weights of the similarity and continuity terms in the total loss.
stepsize_sim = 0.5
stepsize_con = 1



In [None]:
def train(data_loader, epochs, minLabels, restore_best=False):
    """Train the global ``model`` on unlabelled image batches.

    For every batch the network output is treated as per-pixel class scores.
    The loss is the weighted sum of

    * a similarity term: cross-entropy between the scores and their own
      arg-max labels, weighted by ``stepsize_sim``;
    * a continuity term: L1 norm of the differences between vertically and
      horizontally adjacent output pixels, weighted by ``stepsize_con``.

    The function relies on these notebook-level globals: ``model``,
    ``optimizer``, ``device``, ``loss_fn``, ``loss_hpy``, ``loss_hpz``,
    ``stepsize_sim`` and ``stepsize_con``.

    Args:
        data_loader: iterable yielding image batches; each batch is moved to
            ``device`` and fed to ``model``, whose output must be 4-D
            ``(N, C, H, W)``.
        epochs: maximum number of passes over ``data_loader``.
        minLabels: training stops early once the number of distinct labels in
            the last batch of an epoch is less than or equal to this value.
        restore_best: when True, a copy of the weights is kept every time a
            batch reaches a new lowest loss, and those weights are loaded back
            into ``model`` when training finishes. Defaults to False, which
            leaves ``model`` with the weights of the final optimisation step.

    Raises:
        ValueError: if ``data_loader`` yields no batches during an epoch.
    """
    # Start from +inf so the first batch always counts as an improvement
    # (a start value of 0.0 can never be beaten by a non-negative loss).
    best_loss = float("inf")
    best_model_weights = None

    for epoch in range(1, epochs + 1):
        nLabels = None
        batch_loss = None

        for X in data_loader:
            X = X.to(device)
            optimizer.zero_grad()

            output = model(X)
            n_channels = output.shape[1]

            # (N, C, H, W) -> (N, H, W, C), plus a flat (N*H*W, C) view of it
            output_hwc = output.permute(0, 2, 3, 1).contiguous()
            output_flat = output_hwc.view(-1, n_channels)

            # Continuity: neighbouring pixels should have similar responses.
            HPy = output_hwc[:, 1:, :, :] - output_hwc[:, :-1, :, :]
            HPz = output_hwc[:, :, 1:, :] - output_hwc[:, :, :-1, :]
            lhpy = loss_hpy(HPy, torch.zeros_like(HPy))
            lhpz = loss_hpz(HPz, torch.zeros_like(HPz))

            # Similarity: the arg-max channel of every pixel is its pseudo label.
            _, target = torch.max(output_flat, 1)
            nLabels = len(np.unique(target.cpu().numpy()))

            loss = stepsize_sim * loss_fn(output_flat, target) + stepsize_con * (lhpy + lhpz)

            loss.backward()
            optimizer.step()

            # Keep a plain float so no autograd graph is retained across batches.
            batch_loss = loss.item()
            if restore_best and batch_loss < best_loss:
                best_loss = batch_loss
                best_model_weights = copy.deepcopy(model.state_dict())

        if batch_loss is None:
            raise ValueError("data_loader yielded no batches")

        print(epoch, '/', epochs, '|', ' label num :', nLabels, ' | loss :', batch_loss)

        if nLabels <= minLabels:
            print("nLabels", nLabels, "reached minLabels", minLabels, ".")
            break

    if restore_best and best_model_weights is not None:
        model.load_state_dict(best_model_weights)

In [None]:
# Up to 100 epochs; stop early once at most 8 distinct labels remain.
# NOTE(review): this comment previously said learning_rate = 0.05, but the
# optimizer defined before this cell is SGD with lr=0.01 -- confirm which
# setting produced the logged output.
train(crop_dataloader, 100, 8)  # batch size = 32, learning rate = 0.01

1 / 100 |  label num : 350  | loss : 2.932061195373535
2 / 100 |  label num : 133  | loss : 2.0080976486206055
3 / 100 |  label num : 49  | loss : 1.1151894330978394
4 / 100 |  label num : 36  | loss : 0.6813901662826538
5 / 100 |  label num : 32  | loss : 0.546470582485199
6 / 100 |  label num : 31  | loss : 0.4888755977153778
7 / 100 |  label num : 32  | loss : 0.388486385345459
8 / 100 |  label num : 32  | loss : 0.4287705421447754
9 / 100 |  label num : 32  | loss : 0.3381318151950836
10 / 100 |  label num : 30  | loss : 0.33792275190353394
11 / 100 |  label num : 30  | loss : 0.29082778096199036
12 / 100 |  label num : 32  | loss : 0.2973807454109192
13 / 100 |  label num : 32  | loss : 0.26602986454963684
14 / 100 |  label num : 30  | loss : 0.26175472140312195
15 / 100 |  label num : 31  | loss : 0.24638856947422028
16 / 100 |  label num : 32  | loss : 0.2344052791595459
17 / 100 |  label num : 30  | loss : 0.2376026213169098
18 / 100 |  label num : 30  | loss : 0.22645655274391

In [None]:
# Second experiment: start again from a randomly initialised, truncated
# ResNet-18, with a larger learning rate and a smaller batch size.
resnet_18 = torch.nn.Sequential(
    *list(resnet18(pretrained=False).children())[:-2]
)
model = resnet_18.train().to(device)
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0.1)

crop_dataloader = DataLoader(
    crop_dataset,
    shuffle=True,
    batch_size=16,
)

In [None]:
# Up to 100 epochs; stop early once at most 8 distinct labels remain.
train(crop_dataloader, 100, 8)  # batch size = 16, learning rate = 0.1

1 / 100 |  label num : 27  | loss : 0.6706045866012573
2 / 100 |  label num : 25  | loss : 0.25471556186676025
3 / 100 |  label num : 18  | loss : 0.2398880422115326
4 / 100 |  label num : 17  | loss : 0.29093819856643677
5 / 100 |  label num : 16  | loss : 0.1495736837387085
6 / 100 |  label num : 16  | loss : 0.14788992702960968
7 / 100 |  label num : 15  | loss : 0.13239479064941406
8 / 100 |  label num : 14  | loss : 0.14970757067203522
9 / 100 |  label num : 16  | loss : 0.20074601471424103
10 / 100 |  label num : 15  | loss : 0.14906485378742218
11 / 100 |  label num : 14  | loss : 0.11265918612480164
12 / 100 |  label num : 14  | loss : 0.13614700734615326
13 / 100 |  label num : 14  | loss : 0.13541443645954132
14 / 100 |  label num : 14  | loss : 0.14484724402427673
15 / 100 |  label num : 14  | loss : 0.1350121945142746
16 / 100 |  label num : 14  | loss : 0.1564665287733078
17 / 100 |  label num : 14  | loss : 0.08787835389375687
18 / 100 |  label num : 13  | loss : 0.085669

# Use ResNet pretrained weights

In [None]:
# Third experiment: truncated ResNet-18 initialised from torchvision's
# pretrained weights, optimised with Adam instead of SGD.
resnet_18 = torch.nn.Sequential(
    *list(resnet18(pretrained=True).children())[:-2]
)
model = resnet_18.train().to(device)

# Alternative tried earlier: optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=0.01)

crop_dataloader = DataLoader(
    crop_dataset,
    shuffle=True,
    batch_size=16,
)

In [None]:
# Up to 100 epochs; stop early once at most 8 distinct labels remain.
# Uses the pretrained backbone, Adam (lr=0.01) and batch size 16 set up in the previous cell.
train(crop_dataloader, 100, 8)

1 / 100 |  label num : 36  | loss : 0.377993106842041
2 / 100 |  label num : 31  | loss : 0.21667860448360443
3 / 100 |  label num : 26  | loss : 0.18946543335914612
4 / 100 |  label num : 22  | loss : 0.13776540756225586
5 / 100 |  label num : 18  | loss : 0.10034201294183731
6 / 100 |  label num : 18  | loss : 0.11784830689430237
7 / 100 |  label num : 16  | loss : 0.06966164708137512
8 / 100 |  label num : 17  | loss : 0.07500705122947693
9 / 100 |  label num : 14  | loss : 0.07832532376050949
10 / 100 |  label num : 15  | loss : 0.0581207200884819
11 / 100 |  label num : 14  | loss : 0.0528411939740181
12 / 100 |  label num : 13  | loss : 0.0680113360285759
13 / 100 |  label num : 13  | loss : 0.05786644667387009
14 / 100 |  label num : 13  | loss : 0.045792996883392334
15 / 100 |  label num : 12  | loss : 0.042237892746925354
16 / 100 |  label num : 12  | loss : 0.03820480778813362
17 / 100 |  label num : 12  | loss : 0.04540208727121353
18 / 100 |  label num : 11  | loss : 0.0393

In [None]:
def viz(src, dst):
    """Segment one image with the global ``model`` and save a colour-coded label map.

    The image at ``src`` is read with OpenCV, resized to 224x224, scaled to
    [0, 1] and passed through ``model`` on ``device``. Every output pixel is
    labelled with the index of its strongest channel, each label is mapped to
    a random colour, and the resulting map is enlarged to 224x224 and written
    to ``dst``. The flattened output shape, the output tensor and the first
    label are printed for inspection.

    Args:
        src: path of the input image.
        dst: path the colour-coded label map is written to.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``src``.
        IOError: if OpenCV cannot write ``dst``.
    """
    img = cv2.imread(src)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Could not read image: {}".format(src))
    img = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_AREA)

    # HWC uint8 -> 1 x C x H x W float32 in [0, 1]
    img = torch.from_numpy(img.transpose((2, 0, 1)).astype("float32")) / 255.
    img = img.unsqueeze(0).to(device=device)

    # Run inference in eval mode without autograd so that visualising does not
    # update BatchNorm running statistics; restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(img)
    finally:
        model.train(was_training)

    output = output.squeeze(0)
    n_channels, out_h, out_w = output.shape

    # (C, h, w) -> (h*w, C): one row of channel scores per output pixel
    output = output.permute(1, 2, 0).contiguous().view(-1, n_channels)

    print(output.shape)

    _, target = torch.max(output, 1)
    im_target = target.cpu().numpy()

    print(output)

    # One random colour per possible label (i.e. per output channel).
    label_colours = np.random.randint(255, size=(n_channels, 3))
    im_target_rgb = label_colours[im_target]

    print(im_target[0])

    # The label map has the network's output resolution; enlarge it with
    # nearest-neighbour interpolation so label colours are not blended.
    im_target_rgb = im_target_rgb.reshape(out_h, out_w, 3).astype(np.uint8)
    im_target_rgb = cv2.resize(im_target_rgb, dsize=(224, 224), interpolation=cv2.INTER_NEAREST)

    if not cv2.imwrite(dst, im_target_rgb):
        raise IOError("Could not write image: {}".format(dst))

In [None]:
# Input image for the visualisation and the path its label map is meant to be saved to.
# NOTE(review): the file name ends in `.jpe` -- confirm this is the real
# extension and not a truncated `.jpeg`.
src = "/content/drive/MyDrive/Data-20220522T105527Z-002/Rye/scout_point_image_20220318T145503000Z.jpe"
dst = "/content/output.jpeg"

In [None]:
# Run the current `model` on the sample image and inspect its per-pixel output.
viz(src, dst)

torch.Size([49, 512])
tensor([[0.4341, 0.4522, 0.4872,  ..., 0.3688, 6.8607, 0.4137],
        [0.4177, 0.4575, 0.4587,  ..., 0.4750, 6.3368, 0.3434],
        [0.4718, 0.5475, 0.5040,  ..., 0.5492, 0.0000, 0.5337],
        ...,
        [0.3869, 0.4523, 0.4652,  ..., 0.4530, 0.0000, 0.1902],
        [0.4135, 0.5311, 0.4793,  ..., 0.5355, 0.0000, 0.3794],
        [0.5319, 0.5521, 0.4868,  ..., 0.5390, 0.0000, 0.4583]],
       device='cuda:0')
274
