In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from sklearn.model_selection import StratifiedKFold,KFold

# Network input size (pixels) and the mini-batch size shared by every DataLoader.
img_row = 224
img_col = 224
batch_size = 18

# Project layout: <main_path>/{img, data, saved_models}.
# NOTE(review): absolute, machine-specific root; adjust when running elsewhere.
subject = 'hiaa2'
main_path = os.path.join("E:\\kaggle_imgs", subject)
img_path = os.path.join(main_path, "img")
data_path = os.path.join(main_path, "data")
saved_path = os.path.join(main_path, "saved_models")
paths = [main_path, img_path, saved_path, data_path]
for fp in paths:
    print(fp)
    # makedirs creates missing parents too (os.mkdir raised FileNotFoundError
    # when the root folder did not exist) and tolerates existing directories.
    os.makedirs(fp, exist_ok=True)

# Checkpoint that the "Load Best" cell tries to restore.
file_path = os.path.join(saved_path, "epoch_12_loss_0.003064.pth")
file_best = file_path

# Data artefacts kept under <data_path>.
train_img_pkl = os.path.join(data_path, "train.csv")
test_img_pkl = os.path.join(data_path, "test_imgs.npy")
train_info_pkl = os.path.join(data_path, "df_train_pickle.csv")

num_classes = 4

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

## Make data file

In [None]:
def read_train_info(use_cache=False):
    """Build (or reload) the training index table with a 5-fold assignment.

    The table has one row per file in ``<img_path>/train`` whose name does not
    contain ``"mask"``, with the columns ``image_name``, ``id`` (integer file
    stem), ``mask_name`` (``"<id>_mask.png"``) and ``fold`` (0-4). The freshly
    built table is written to ``train_img_pkl`` as CSV.

    Args:
        use_cache: when True and the CSV already exists, load it instead of
            rescanning the image folder. Defaults to False, which matches the
            previous behaviour (the cache branch was hard-disabled).

    Returns:
        pandas.DataFrame: the training index table.
    """
    fp = train_img_pkl
    if use_cache and os.path.exists(fp):
        try:
            print('loading train data from csv', flush=True)
            df_train = pd.read_csv(fp)
            print('complete!', flush=True)
            return df_train
        except (EOFError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
            # Fall through and rebuild instead of returning an unbound name.
            print('could not read cached csv (%s); rebuilding.' % err, flush=True)

    # Sort the listing: os.listdir order is arbitrary, which would make the
    # fold assignment differ between runs/machines.
    files = sorted(os.listdir(os.path.join(img_path, "train")))
    df_train = pd.DataFrame({"image_name": files})
    df_train = df_train[~df_train.image_name.str.contains('mask')].reset_index(drop=True)
    df_train["id"] = df_train["image_name"].apply(lambda x: int(x.split('.')[0]))
    df_train["mask_name"] = df_train["id"].apply(lambda x: str(x) + "_mask.png")

    # random_state is only valid together with shuffle=True; recent
    # scikit-learn raises ValueError otherwise.
    kf = KFold(n_splits=5, shuffle=True, random_state=22)
    df_train["fold"] = -1
    for i, (_, valid_idx) in enumerate(kf.split(df_train.id)):
        # Index was reset above, so positional indices equal labels.
        df_train.loc[valid_idx, "fold"] = i
    df_train.to_csv(fp, index=False)

    return df_train


df_train = read_train_info()

## Augmentation

In [None]:
from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, RandomBrightnessContrast, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, Flip, OneOf, Compose, Rotate, Cutout, VerticalFlip, Normalize
)
from albumentations.pytorch import ToTensor

from torchvision import transforms


In [None]:
# Image pipelines: convert to a tensor, then normalise with the fixed
# per-channel mean/std below. Train, validation and test currently share the
# same recipe, but each name gets its own Compose instance.
train_transforms, valid_transforms, test_transforms = (
    transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    for _ in range(3)
)

# Masks are only converted to tensors; no normalisation is applied.
mask_transforms = transforms.Compose([transforms.ToTensor()])


In [None]:
# train_transforms= Compose([
# #     Rotate(15),
# #     OneOf([
# #         IAAAdditiveGaussianNoise(),
# #         GaussNoise(),
# #     ], p=0.2),
# #     OneOf([
# #         MotionBlur(p=0.2),
# #         MedianBlur(blur_limit=3, p=0.1),
# #         Blur(blur_limit=3, p=0.1),
# #     ], p=0.2),
# #     ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
# #     OneOf([
# #         OpticalDistortion(p=0.3),
# #         GridDistortion(p=0.1),
# #         IAAPiecewiseAffine(p=0.3),
# #     ], p=0.2),
# #     OneOf([
# #         CLAHE(clip_limit=2),
# #         IAASharpen(),
# #         IAAEmboss(),
# #         RandomBrightnessContrast(),
# #     ], p=0.3),
# #     HueSaturationValue(p=0.3),
# #     Normalize(),
#     ToTensor()
# ])
# valid_transforms=Compose([
# #     Normalize(),
#     ToTensor()
# ])

## Dataset

In [None]:
# Smoke test: push one training image through `train_transforms` and show it.
from PIL import Image

# Context manager closes the file once the pixels have been copied out.
with Image.open(os.path.join(img_path, "train", "11000.jpg")) as pil_img:
    img = np.uint8(pil_img)
img = train_transforms(img)

# `img` is normalised (values fall outside [0, 1]); undo that for display,
# otherwise imshow clips the data and the colours are wrong.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
plt.imshow((img * _std + _mean).clamp(0, 1).permute(1, 2, 0))

# Summarise the tensor rather than dumping every value.
print(img.shape, img.dtype, img.min().item(), img.max().item())

In [None]:
from torch.utils.data import Dataset
from PIL import Image

def rgb2gray(rgb):
    """Collapse the first three channels of the last axis of an (H, W, C) array to grey.

    The result is the weighted sum 0.2989*R + 0.5870*G + 0.1140*B, where R, G
    and B are ``rgb[:, :, 0]``, ``rgb[:, :, 1]`` and ``rgb[:, :, 2]``.
    """
    weights = (0.2989, 0.5870, 0.1140)
    planes = [rgb[:, :, channel] for channel in range(3)]

    # Accumulate left to right so the floating-point result matches
    # the single-expression form exactly.
    luminance = weights[0] * planes[0]
    for weight, plane in zip(weights[1:], planes[1:]):
        luminance = luminance + weight * plane
    return luminance

class PlantDataset(Dataset):
    """Image / mask pairs listed in a DataFrame.

    Args:
        df: table with an ``image_name`` column and, when ``subset == "train"``,
            a ``mask_name`` column.
        tr: optional callable applied to the uint8 image array. When None the
            raw array is returned unchanged.
        subset: ``"test"`` reads images from ``<img_path>/test``; any other
            value reads from ``<img_path>/train``. Masks are loaded only when
            ``subset == "train"``; otherwise an all-zero placeholder shaped
            like the image is returned.
    """

    def __init__(self, df, tr=None, subset="train"):
        self.df = df
        self.tr = tr
        self.subset = subset

    def __len__(self):
        """Number of rows in the backing table."""
        return self.df.shape[0]

    @staticmethod
    def _read_uint8(folder, file_name):
        """Load ``<img_path>/<folder>/<file_name>`` as a uint8 array and close the file."""
        with Image.open(os.path.join(img_path, folder, file_name)) as handle:
            return np.uint8(handle)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the row at position `idx`."""
        row = self.df.iloc[idx]
        folder = "test" if self.subset == "test" else "train"
        img = self._read_uint8(folder, row.image_name)
        # `tr` defaults to None; calling it unconditionally raised TypeError.
        if self.tr is not None:
            img = self.tr(img)

        if self.subset == "train":
            mask = mask_transforms(self._read_uint8("train", row.mask_name))
        else:
            # No ground truth available: placeholder so batches keep a fixed structure.
            mask = np.zeros_like(img)

        return img, mask

## Data loaders (train / validation split by fold)

In [None]:
def get_images_by_fold(fold):
    """Return ``(train_loader, valid_loader)`` with `fold` held out for validation.

    Rows of the global ``df_train`` whose ``fold`` value is one of 0-4 other
    than `fold` feed the shuffled training loader; rows whose ``fold`` equals
    `fold` feed the unshuffled validation loader. Both datasets use the
    default ``subset="train"``, so masks are loaded for validation as well.
    """
    all_folds = range(5)
    train_folds = [k for k in all_folds if k != fold]
    held_out = [k for k in all_folds if k == fold]

    in_train = df_train.fold.isin(train_folds)
    in_valid = df_train.fold.isin(held_out)
    train_rows = df_train.loc[df_train[in_train].index]
    valid_rows = df_train.loc[df_train[in_valid].index]

    # Settings shared by both loaders; only `shuffle` differs.
    common = dict(batch_size=batch_size, num_workers=0)
    train_loader = torch.utils.data.DataLoader(
        PlantDataset(df=train_rows, tr=train_transforms), shuffle=True, **common
    )
    valid_loader = torch.utils.data.DataLoader(
        PlantDataset(df=valid_rows, tr=valid_transforms), shuffle=False, **common
    )
    return train_loader, valid_loader

In [None]:
# Fold 0 is held out for validation; folds 1-4 are used for training.
train_loader, valid_loader = get_images_by_fold(fold=0)

In [None]:
# Pull a single batch for inspection. next(iter(...)) raises StopIteration on
# an empty loader instead of silently leaving `img` / `mask` stale or undefined.
img, mask = next(iter(train_loader))

In [None]:
# Per-sample shapes of the first mask and image in the batch.
mask[0].shape, img[0].shape


In [None]:
# Show every image of the sampled batch next to its mask.
# The images are normalised by the transforms, so undo that before plotting;
# otherwise imshow clips everything outside [0, 1].
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# A batch may hold fewer than `batch_size` samples (small dataset / last batch).
for i in range(min(batch_size, img.size(0))):
    f, ax = plt.subplots(1, 2, figsize=(14, 5))
    ax[0].imshow((img[i] * _std + _mean).clamp(0, 1).permute(1, 2, 0))
    ax[1].imshow(mask[i].permute(1, 2, 0).squeeze())
    plt.show()
    plt.close(f)  # release the figure; many are created in this loop

In [None]:
# Shape of the last mask visited by the plotting loop.
mask[i].size()

## Model

In [None]:
import torch
import torch.nn as nn
from torchvision import models

def convrelu(in_channels, out_channels, kernel, padding):
    """Return an ``nn.Sequential`` of a Conv2d followed by an in-place ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel: convolution kernel size.
        padding: padding passed to the convolution.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)


class ResNetUNet(nn.Module):
    """U-Net style segmentation network with a torchvision ResNet-18 encoder.

    The children of ``models.resnet18(pretrained=True)`` are sliced into five
    encoder stages (``layer0`` .. ``layer4``). In ``forward`` every stage
    output goes through its own 1x1 conv + ReLU and is concatenated with the
    bilinearly upsampled decoder features. A separate two-convolution branch on
    the raw input supplies a skip connection at the input resolution.

    Args:
        n_class: number of output channels of the final 1x1 convolution.
    """
    def __init__(self, n_class):
        super().__init__()

        # NOTE(review): the encoder modules are registered both under
        # ``base_model`` and under ``layer0``..``layer4``, so their weights show
        # up under both prefixes in ``state_dict()``. Saved checkpoints depend
        # on these names; keep them stable.
        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        # Encoder stages plus the 1x1 conv applied to each skip connection.
        # Size comments are relative to the input x.
        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        # One shared x2 bilinear upsampler, reused at every decoder step.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder convolutions; input channels = skip channels + upsampled channels.
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch on the 3-channel input and the final fusion conv.
        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        # 1x1 projection to `n_class` channels.
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        """Run the encoder-decoder and return the ``conv_last`` output.

        No activation is applied to the returned tensor. `input` must have
        3 channels (see ``conv_original_size0``).
        NOTE(review): presumably H and W must be divisible by 32 so that the
        upsampled maps line up with the skip tensors - confirm.
        """
        # Skip connection computed at the input resolution.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder pass.
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder: upsample, concatenate the matching skip along channels, convolve.
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        # Final upsample, then fuse with the full-resolution branch.
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Summarise the architecture that is actually trained: masks are single-channel
# and the training cell builds the network with one output class, so use the
# same head here (the previous n_class=6 described a different model).
model = ResNetUNet(n_class=1)
model = model.to(device)

# check keras-like model summary using torchsummary
from torchsummary import summary
summary(model, input_size=(3, img_row, img_col))

## Define the main training loop

In [None]:
from collections import defaultdict
import torch.nn.functional as F
from loss import dice_loss

def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Weighted sum of BCE-with-logits and dice loss, with running totals.

    Args:
        pred: raw model output (logits).
        target: ground-truth tensor passed to both loss terms.
        metrics: mapping updated in place; ``'bce'``, ``'dice'`` and ``'loss'``
            are each increased by the batch value times ``target.size(0)``.
        bce_weight: weight of the BCE term; the dice term gets ``1 - bce_weight``.

    Returns:
        The combined loss tensor (still attached to the graph).
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # dice_loss is fed probabilities; torch.sigmoid replaces the deprecated F.sigmoid.
    dice = dice_loss(torch.sigmoid(pred), target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # .item() yields plain Python floats detached from the graph.
    n_samples = target.size(0)
    metrics['bce'] += bce.item() * n_samples
    metrics['dice'] += dice.item() * n_samples
    metrics['loss'] += loss.item() * n_samples

    return loss

def print_metrics(metrics, epoch_samples, phase):
    """Print one line with the per-sample average of every accumulated metric.

    Each total in `metrics` is divided by `epoch_samples`; the line is
    prefixed with `phase`.
    """
    summary_line = ", ".join(
        "{}: {:4f}".format(name, total / epoch_samples)
        for name, total in metrics.items()
    )
    print("{}: {}".format(phase, summary_line))

def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train `model`, validate after every epoch and keep the best weights.

    Reads the module-level ``train_loader``, ``valid_loader``, ``device`` and
    ``saved_path``. Whenever the validation loss improves, a checkpoint dict
    with the keys ``'epoch'``, ``'state_dict'`` and ``'optimizer'`` is written
    to ``saved_path``.

    Args:
        model: network to optimise (already on ``device``).
        optimizer: optimizer over the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.

    Returns:
        `model` with the weights of the best validation epoch loaded.
    """
    # Local imports: the function no longer depends on a later cell having
    # imported these modules before it is called.
    import copy
    import time

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            my_loader = train_loader if phase == "train" else valid_loader
            for inputs, labels in my_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            if phase == 'train':
                # Step the schedule after this epoch's optimizer updates.
                # Stepping before any optimizer.step() skips the first value
                # of the schedule and triggers a PyTorch warning.
                scheduler.step()

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # keep a copy of the best weights and checkpoint them to disk
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                # %02d zero-pads the epoch; %2d padded with a space and put
                # blanks into the file name for single-digit epochs.
                checkpoint_path = os.path.join(
                    saved_path, "epoch_%02d_loss_%.6f.pth" % (epoch, best_loss)
                )
                torch.save(state, checkpoint_path)

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

## Training

In [None]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

# Train on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Single output channel for the segmentation head.
num_class = 1
model = ResNetUNet(num_class)
model = model.to(device)

# freeze backbone layers
#for l in model.base_layers:
#    for param in l.parameters():
#        param.requires_grad = False

# Only parameters that still require gradients are handed to Adam.
optimizer_ft = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-4,
)

# Multiply the learning rate by 0.1 every 30 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=30, gamma=0.1)


In [None]:
## Load Best
# Restore model and optimizer state from `file_best` when that checkpoint exists.
# Fixes: the stray leading space (IndentationError) and the undefined names
# `optimizer` / `cur_best` (the optimizer is `optimizer_ft`, the path is `file_best`).
if os.path.isfile(file_best):
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(file_best, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer_ft.load_state_dict(checkpoint['optimizer'])
    print("Load Complete", file_best)

In [None]:

# Run 60 epochs; the returned model carries the best validation weights.
model = train_model(
    model=model,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    num_epochs=60,
)

## Use the trained model

In [None]:
import math
model.eval()

# Run the model over the first three validation batches. `inputs`, `labels`
# and `pred` keep the values of the last batch processed (index 2) for plotting.
with torch.no_grad():  # inference only: do not build the autograd graph
    for i, (inputs, labels) in enumerate(valid_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The model returns logits (the loss function applies the sigmoid
        # itself), so convert to probabilities here.
        pred = torch.sigmoid(model(inputs))
        pred = pred.cpu().numpy()
        print(pred.shape)

        inputs = inputs.to("cpu")
        labels = labels.to("cpu")
        if i == 2:
            break


In [None]:
# Input image, ground-truth mask and thresholded prediction, side by side.
# The inputs are normalised tensors: undo the normalisation for display.
# The previous `* 255` on a float image pushed values out of imshow's [0, 1] range.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# The batch may be shorter than `batch_size`; never index past its end.
for i in range(min(batch_size, len(pred))):
    f, ax = plt.subplots(1, 3, figsize=(15, 7))
    ax[0].imshow((inputs[i] * _std + _mean).clamp(0, 1).permute(1, 2, 0))
    ax[1].imshow(labels[i].permute(1, 2, 0).squeeze())
    # Binarise the predicted probabilities at 0.001.
    img1 = np.uint8(pred[i].squeeze() > 0.001)
    #ax[2].imshow(inputs[i].permute(1,2,0)*255)
    ax[2].imshow(img1, cmap="jet", alpha=0.5)
    plt.show()
    plt.close(f)  # release the figure; many are created in this loop

## Test

In [None]:
# Index the test images. os.listdir returns entries in arbitrary order, so sort
# them to keep the test ordering (and prediction order) reproducible.
files = sorted(os.listdir(os.path.join(img_path, "test")))
df_test = pd.DataFrame({"image_name": files})


In [None]:
# Test dataset: images come from the test folder and no masks are loaded.
testset = PlantDataset(df_test, test_transforms, "test")

# Keep file order (no shuffling) so predictions line up with `df_test`.
test_loader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Pull the first test batch. next(iter(...)) raises StopIteration on an empty
# loader instead of silently keeping `img` / `mask` from the training cells.
img, mask = next(iter(test_loader))

In [None]:
# Preview the first test image. The tensor is normalised, so undo that first;
# imshow would otherwise clip everything outside [0, 1].
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
plt.imshow((img[0] * _std + _mean).clamp(0, 1).permute(1, 2, 0))

In [None]:
import math
model.eval()

# Run the model over the first two test batches. `inputs`, `labels` and `pred`
# keep the values of the last batch processed (index 1) for plotting.
with torch.no_grad():  # inference only: do not build the autograd graph
    for i, (inputs, labels) in enumerate(test_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The model returns logits (the loss function applies the sigmoid
        # itself), so convert to probabilities here.
        pred = torch.sigmoid(model(inputs))
        pred = pred.cpu().numpy()
        print(pred.shape)

        inputs = inputs.to("cpu")
        labels = labels.to("cpu")
        if i == 1:
            break

In [None]:
# Test image, placeholder mask (all zeros for the test subset) and the
# predicted probability map, side by side.
# The inputs are normalised tensors: undo the normalisation for display.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# The batch may be shorter than `batch_size`; never index past its end.
for i in range(min(batch_size, len(pred))):
    f, ax = plt.subplots(1, 3, figsize=(15, 7))
    ax[0].imshow((inputs[i] * _std + _mean).clamp(0, 1).permute(1, 2, 0))
    ax[1].imshow(labels[i].permute(1, 2, 0).squeeze())
    #img1=np.uint8(pred[i].squeeze()>0.001)
    #ax[2].imshow(inputs[i].permute(1,2,0)*255)
    ax[2].imshow(pred[i].squeeze(), cmap="jet", alpha=0.5)
    plt.show()
    plt.close(f)  # release the figure; many are created in this loop

In [None]:
# Range and total of the last plotted prediction map (min was listed twice; the
# second entry is meant to be the maximum).
pred[i].squeeze().min(), pred[i].squeeze().max(), pred[i].squeeze().sum()

In [None]:
# Debug peek: the normalised input tensor of the last plotted sample.
inputs[i]