In [None]:
!pip install torchvision

In [None]:
!pip install torchinfo

In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git

In [None]:
import copy
import cv2
import os
import random

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset

from torchinfo import summary

import torchvision
import torchvision.transforms as transforms
import torchvision.utils

from zipfile import ZipFile

from transformers import AutoImageProcessor, Swinv2Model

In [None]:
# Use the GPU when PyTorch can see one, otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Mount Google Drive at /content/gdrive; the data and model paths configured
# further down all live under this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Run-mode switches.
#   use_pre_trained: True  -> load the model that was already trained and saved
#                    False -> train the model in this session
use_pre_trained = False

# When save_input_classes and save_train_test_splitting are both True, the
# initial splitting of the data into training and test sets is performed
# (the class lists and the split lists are regenerated and written to disk).
save_input_classes, save_train_test_splitting = True, True

In [None]:
# Google Drive folders: DATA_PATH holds the dataset archive and the generated
# list files, CURR_PATH holds the model checkpoints.
DATA_PATH = '/content/gdrive/MyDrive/'
CURR_PATH = '/content/gdrive/MyDrive/Colab Notebooks/'

# Base name for the checkpoints saved during training (an "_<epoch>" suffix is
# appended per epoch by the training loop).
FILE_NAME = CURR_PATH + 'model_seg_pytorch_medical.cnn'
# Previously trained model that is loaded when use_pre_trained is True.
# NOTE(review): the suffix is _23 while ep_num below is 22, so a run of this
# notebook does not produce this file itself - confirm it exists on Drive.
FILE_NAME_PR = CURR_PATH + 'model_seg_pytorch_medical.cnn_23'

# r_size: side length images and masks are resized to;
# batch_size: training batch size; ep_num: number of training epochs.
r_size, batch_size, ep_num = 256, 16, 22

# Dataset archive, expected inside DATA_PATH.
zip_name = 'ISSBI2015.zip'

In [None]:
with ZipFile(DATA_PATH + zip_name, 'r') as f:
    names = f.namelist()

len(names), names[10]

In [None]:
names = [n.split('/')[-1] for n in names]
names = list(filter(lambda x: (x != '' and x != 'data.csv' and x != 'README.md'), names))
names[0], len(names)

In [None]:
f_names = [ 'ISSBI2015/' + n[:8]+n[9] + '/' + n for n in names]
f_names[0]

In [None]:
f_images = list(filter(lambda x: 'mask' not in x, f_names))
f_images[0], len(f_images)

In [None]:
f_masks = list(filter(lambda x: 'mask' in x, f_names))
f_masks[0], len(f_masks)

In [None]:
del f_names
del f_masks

In [None]:
def show_input_sample(f_images):
    """Display one randomly chosen image from ``f_images`` next to its mask.

    The mask entry name is derived from the image entry name by replacing its
    last five characters with ``'+mask.tiff'``. The mask name and the mask's
    max, min and unique values are printed before the figure is shown.

    Args:
        f_images: non-empty sequence of image entry names inside the archive
            ``DATA_PATH + zip_name``.

    Raises:
        IndexError: if ``f_images`` is empty.
        KeyError: if the image or its mask is missing from the archive.
    """
    # Pick a sample without reordering the caller's list (shuffling in place
    # just to read element 0 silently changed the order of the argument).
    img_name = random.choice(f_images)
    mask_name = img_name[:-5] + '+mask.tiff'

    with ZipFile(DATA_PATH + zip_name) as archive:
        data = archive.read(img_name)
        print(mask_name)
        data1 = archive.read(mask_name)

    # Flag 1 makes OpenCV decode to a 3-channel BGR image.
    img = cv2.imdecode(np.frombuffer(data, np.uint8), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    mask = cv2.imdecode(np.frombuffer(data1, np.uint8), 1)
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    print(mask.max(), mask.min(), np.unique(mask))

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))
    axes[0].imshow(img)
    axes[0].set_title('image')
    axes[1].imshow(mask)
    axes[1].set_title('mask')
    plt.show()

In [None]:
# Preview one image/mask pair picked from the full image list.
show_input_sample(f_images)

In [None]:
if save_input_classes:
    # Sort the images into two classes by their mask: f_images0 collects the
    # images whose mask never reaches intensity 127, f_images1 all the others.
    f_images0 = []
    f_images1 = []

    # Open the archive once for the whole pass instead of once per image.
    with ZipFile(DATA_PATH + zip_name) as archive:
        for f in f_images:
            mask_name = f[:-5] + '+mask.tiff'
            print(mask_name)
            # Only the mask decides the class, so the image bytes are not read.
            # The maximum is independent of channel order, so no colour
            # conversion is needed before taking it.
            mask = cv2.imdecode(np.frombuffer(archive.read(mask_name), np.uint8), 1)
            if mask.max() < 127:
                f_images0.append(f)
            else:
                f_images1.append(f)

    # Persist both class lists, one archive path per line.
    with open(DATA_PATH + "fN_im0.txt", "w") as fl:
        fl.writelines(name + '\n' for name in f_images0)

    with open(DATA_PATH + "fN_im1.txt", "w") as fl:
        fl.writelines(name + '\n' for name in f_images1)

In [None]:
# Reload both class lists from disk (one archive path per line) so the rest
# of the notebook works the same whether or not they were just regenerated.
with open(DATA_PATH + "fN_im0.txt", "r") as fl:
    f_images0 = [line.strip() for line in fl]

with open(DATA_PATH + "fN_im1.txt", "r") as fl:
    f_images1 = [line.strip() for line in fl]

In [None]:
# Class sizes: f_images0 = mask maximum below 127, f_images1 = all others.
len(f_images0), len(f_images1)

In [None]:
# Preview a sample from the class whose mask reaches intensity 127 or more.
show_input_sample(f_images1)

In [None]:
# Preview a sample from the class whose mask stays below intensity 127.
show_input_sample(f_images0)

In [None]:
if save_train_test_splitting:
    # 90/10 train/test split, done separately for each class so both classes
    # are represented in both sets.
    random.shuffle(f_images0)
    train_len0 = int(0.9 * len(f_images0))
    df_train0 = f_images0[:train_len0]
    df_test0 = f_images0[train_len0:]

    random.shuffle(f_images1)
    train_len1 = int(0.9 * len(f_images1))
    df_train1 = f_images1[:train_len1]
    df_test1 = f_images1[train_len1:]

    # To make a balanced training dataset, repeat class-1 samples until both
    # classes have the same size. Cycling through the list also covers the
    # case where more than one full repetition is needed; when class 1 is
    # already at least as large as class 0 nothing is duplicated (a negative
    # slice bound would otherwise duplicate an unintended part of the list).
    diff = len(df_train0) - len(df_train1)
    random.shuffle(df_train1)
    if diff > 0 and df_train1:
        n1 = len(df_train1)
        df_train1 = df_train1 + [df_train1[i % n1] for i in range(diff)]

    # Build new lists rather than extending df_train0/df_test0 in place, so
    # the per-class lists stay intact and the written lists are explicit.
    df_train = df_train0 + df_train1
    df_test = df_test0 + df_test1

    with open(DATA_PATH + 'dfN_train.txt', 'w') as fl:
        for f in df_train:
            fl.write(f + '\n')

    with open(DATA_PATH + 'dfN_test.txt', 'w') as fl:
        for f in df_test:
            fl.write(f + '\n')

In [None]:
# Reload the persisted split (one archive path per line) so training and
# evaluation always use the lists stored on disk.
with open(DATA_PATH + 'dfN_train.txt', 'r') as fl:
    df_train = [line.strip() for line in fl]

with open(DATA_PATH + 'dfN_test.txt', 'r') as fl:
    df_test = [line.strip() for line in fl]

# Sizes of the training and test lists, one per line.
print(len(df_train), len(df_test), sep='\n')

In [None]:
# Just to verify that the test data is not included in the training set.
# A set gives constant-time membership checks instead of scanning the whole
# training list for every test entry.
_train_lookup = set(df_train)
for dl in df_test:
    if dl in _train_lookup:
        print('test line in the train set!!!')

In [None]:
from transformers import AutoImageProcessor, Swinv2Config, Swinv2Model

# Step 1: load the model configuration.
config = Swinv2Config.from_pretrained("microsoft/swinv2-large-patch4-window12-192-22k")

# Step 2: build the model without loading the pretrained weights.
# NOTE(review): the backbone therefore keeps its random initialisation.
model_seg = Swinv2Model(config).to(device)
# The backbone is only used as a fixed feature extractor (SegmentDataset
# detaches everything it produces). A model built directly from a config
# starts in training mode, so switch to eval mode to turn off stochastic
# layers and get the same features for the same image on every pass, and
# freeze the parameters so no autograd graph is built for them.
model_seg.eval()
model_seg.requires_grad_(False)

# Step 3: load the input processor as before (it can be kept because it
# concerns the inputs, not the model weights).
image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-large-patch4-window12-192-22k")

In [None]:
#image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-large-patch4-window12-192-22k")
#model_seg = Swinv2Model.from_pretrained("microsoft/swinv2-large-patch4-window12-192-22k").to(device)

In [None]:
class SegmentDataset(Dataset):
    """Dataset that yields backbone features, the mask and the resized image.

    Each item is read from the archive ``DATA_PATH + zip_name``. The image is
    passed through ``image_processor`` and the module-level ``model_seg``
    backbone; the outputs of the embedding layer and of the four encoder
    stages are returned together with the transformed mask and image.

    Args:
        gen_df: sequence of image entry names inside the archive. The mask
            entry name is derived by replacing the last five characters of the
            image name with ``'+mask.tiff'``.
        transform: callable applied to both the RGB image and the mask. The
            default converts to a tensor and resizes to ``r_size x r_size``
            with nearest-exact interpolation.
        mask_color: ``'GRAY'`` to convert the mask to a single channel; any
            other value keeps three channels (RGB order).
    """

    def __init__(self,
                 gen_df,
                 transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((r_size, r_size), interpolation=transforms.InterpolationMode.NEAREST_EXACT)
                 ]),
                 mask_color='GRAY'):
        self.gen_df = gen_df
        self.transform = transform
        self.mask_color = mask_color

    def __getitem__(self, index):
        """Return ``(img0, img1, img2, img3, img4, mask, img_in)`` for one sample.

        ``img0`` is the embedding output, ``img1``-``img3`` are the outputs of
        encoder stages 0-2 and ``img4`` is the layer-normalised output of
        stage 3, each with the batch dimension squeezed out. These tensors
        live on ``device``; ``mask`` and ``img_in`` are the outputs of
        ``self.transform``.
        """
        img_name = self.gen_df[index]

        # The archive is opened per item, so no open handle is kept on the dataset.
        with ZipFile(DATA_PATH + zip_name) as archive:
            data = archive.read(img_name)
            data1 = archive.read(img_name[:-5] + '+mask.tiff')

        # Flag 1 makes OpenCV decode to a 3-channel BGR image.
        img = cv2.imdecode(np.frombuffer(data, np.uint8), 1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imdecode(np.frombuffer(data1, np.uint8), 1)
        if self.mask_color == 'GRAY':
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

        img_in = self.transform(img)
        mask = self.transform(mask)

        img = Image.fromarray(img)
        img = image_processor(images=img, return_tensors="pt")

        # The backbone is never trained here, so run it without autograd:
        # no graph is recorded and the activations are freed immediately,
        # instead of building a graph and detaching every output afterwards.
        with torch.no_grad():
            x = model_seg.embeddings(**img.to(device))
            input_dimensions = x[1]
            img0 = x[0].squeeze(0)

            # Each stage is called with the token-grid size of its input;
            # the grid side is halved from one stage to the next.
            x = model_seg.encoder.layers[0](x[0], input_dimensions=input_dimensions)
            img1 = x[0].squeeze(0)

            x = model_seg.encoder.layers[1](x[0], input_dimensions=(input_dimensions[0]//2, input_dimensions[1]//2))
            img2 = x[0].squeeze(0)

            x = model_seg.encoder.layers[2](x[0], input_dimensions=(input_dimensions[0]//4, input_dimensions[1]//4))
            img3 = x[0].squeeze(0)

            x = model_seg.encoder.layers[3](x[0], input_dimensions=(input_dimensions[0]//8, input_dimensions[1]//8))
            x = model_seg.layernorm(x[0])
            img4 = x.squeeze(0)

        return img0, img1, img2, img3, img4, mask, img_in

    def __len__(self):
        """Number of samples, i.e. the length of ``gen_df``."""
        return len(self.gen_df)

In [None]:
# Temporary dataset with 3-channel masks so that masks and images can be
# concatenated into a single preview grid below.
try_dataset = SegmentDataset(df_train, mask_color='RGB')

try_dataloader = DataLoader(try_dataset,
                        shuffle=True,
                        num_workers=0,
                        batch_size=8)
dataiter = iter(try_dataloader)

In [None]:
# Fetch one batch and print the shapes of the five feature tensors, the
# masks and the resized images.
i0, i1, i2, i3, i4, y, x1 = next(dataiter)
print(i0.shape, i1.shape, i2.shape, i3.shape, i4.shape)
print(y.shape, x1.shape)

# Stack images and masks along the batch axis and show them as one grid
# (make_grid returns channels-first, imshow expects channels-last).
concatenated = torch.cat((x1, y),0)
c_img = torchvision.utils.make_grid(concatenated).permute(1, 2, 0)
plt.axis("off")
plt.imshow(c_img)

In [None]:
# Drop the preview-only objects.
del try_dataloader
del try_dataset
del concatenated
del c_img

In [None]:
# Dataset used for training: single-channel masks.
train_dataset = SegmentDataset(df_train, mask_color='GRAY')

In [None]:
class Up_Linear(nn.Module):
    """Fuse two token sequences and upsample the token grid by a factor of two.

    Args:
        in_ch: channel count of each of the two inputs.
        size: side length of the square token grid of the inputs
            (the inputs carry ``size * size`` tokens).
        coef: multiplier for the hidden width of the two-layer MLP,
            ``int(coef * in_ch)``.

    ``forward(x1, x2)`` takes two ``(batch, size*size, in_ch)`` tensors and
    returns a ``(batch, 4*size*size, in_ch // 2)`` tensor.
    """

    def __init__(self, in_ch, size, coef=1):
        super().__init__()
        self.shuffle = nn.PixelShuffle(upscale_factor=2)

        fused_ch = in_ch * 2
        hidden_ch = int(coef * in_ch)

        # Token-wise MLP applied to the channel-concatenated inputs.
        self.ln = nn.Sequential(
            nn.Linear(fused_ch, hidden_ch),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_ch, fused_ch),
            nn.ReLU(inplace=True),
        )

        self.size = size

    def forward(self, x1, x2):
        side = self.size
        tokens = self.ln(torch.cat([x1, x2], dim=2))
        batch, _, channels = tokens.shape

        # Tokens -> channels-first square grid, so PixelShuffle can trade
        # channels for spatial resolution.
        grid = tokens.transpose(1, 2).reshape(batch, channels, side, side)
        grid = self.shuffle(grid)

        # Back to a token sequence, now with four times as many tokens.
        flat = grid.reshape(batch, grid.shape[1], side * side * 4)
        return flat.transpose(1, 2)

class MRI_Seg(nn.Module):
    """Decoder that turns the five backbone feature maps into a mask.

    The features are fused from the deepest stage upwards by four
    ``Up_Linear`` blocks, upsampled once more with PixelShuffle, resized to
    ``r_size x r_size`` and mapped to one sigmoid-activated channel.

    ``forward(x0, x1, x2, x3, x4)`` expects the tensors in the order produced
    by ``SegmentDataset`` (embedding output first, final stage last), batched,
    and returns a ``(batch, 1, r_size, r_size)`` tensor of probabilities.
    """

    def __init__(self):
        super(MRI_Seg, self).__init__()

        # (input channels, token-grid side, hidden-width multiplier) per stage.
        self.ups3 = Up_Linear(1536, 6, 1)
        self.ups2 = Up_Linear(768, 12, 1)
        self.ups1 = Up_Linear(384, 24, 2)
        self.ups0 = Up_Linear(192, 48, 3)

        self.shuffle = nn.PixelShuffle(upscale_factor=2)

        self.out = nn.Sequential(
            nn.Conv2d(24, 1, kernel_size=1, stride=1),
            nn.Sigmoid()
        )

    def forward(self, x0, x1, x2, x3, x4):
        x = self.ups3(x4, x3)
        x = self.ups2(x, x2)
        x = self.ups1(x, x1)
        x = self.ups0(x, x0)

        # Token sequence -> channels-first 96x96 grid, then one more 2x upsampling.
        x = x.permute(0, 2, 1)
        x = torch.reshape(x, (x.shape[0], x.shape[1], 96, 96))
        x = self.shuffle(x)
        # Resize to the size the dataset resizes the masks to (r_size), not a
        # separate hard-coded value, so prediction and target always match.
        # The transform is created here instead of being stored on the module
        # so the set of module attributes stays unchanged for models that were
        # saved whole with torch.save.
        x = transforms.Resize((r_size, r_size))(x)

        x = self.out(x)
        return x


In [None]:
train_dataloader = DataLoader(train_dataset,
                        shuffle=True,
                        num_workers=0,
                        batch_size=batch_size)

net = MRI_Seg().to(device)

criterion = nn.BCELoss()
lr = 0.0001
optimizer = optim.Adam(net.parameters(), lr=lr)

In [None]:
len(train_dataloader)

In [None]:
# Take one training batch; these tensors are reused by later cells.
dataiter = iter(train_dataloader)
i0, i1, i2, i3, i4, y, x1 = next(dataiter)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))

# First sample of the batch: mask on the left, resized image on the right
# (permuted to channels-last for imshow).
axes[0].imshow(y[0].permute(1, 2, 0))
axes[1].imshow(x1[0].permute(1, 2, 0))
plt.show()

In [None]:
# Layer-by-layer summary of the decoder; the five input sizes are the
# per-sample feature shapes fed to net, each with a batch dimension of 1.
summary(model=net, input_size=[(1, 2304, 192), (1, 576, 384), (1, 144, 768), (1, 36, 1536), (1, 36, 1536)], col_names=['input_size', 'output_size', 'num_params', 'trainable'])

In [None]:
net.eval().to(device)

In [None]:
out = net(Variable(i0).to(device), Variable(i1).to(device), Variable(i2).to(device), Variable(i3).to(device), Variable(i4).to(device))
print(out.shape, y[0].shape)
print(y[0].max())

ls = criterion(out[0], Variable(y[0]).to(device))

print(ls)
plt.imshow(out[0].cpu().detach().numpy()[0])
del out

In [None]:
def train_net():
    """Train ``net`` for ``ep_num`` epochs on ``train_dataloader``.

    Uses the module-level ``net``, ``optimizer``, ``criterion`` and ``device``.
    The running mean loss is printed every 10 iterations and at the end of
    each epoch. The whole model is saved after every epoch to
    ``FILE_NAME + '_<epoch + 1>'`` and once more to ``FILE_NAME`` when
    training finishes.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    num_iter = len(train_dataloader)
    if num_iter == 0:
        # Fail with a clear message instead of a division by zero below.
        raise ValueError('train_dataloader is empty; nothing to train on')
    ep_init = 0

    for epoch in range(ep_init, ep_num):
        sum_loss = 0.0
        print("Epoch number: {}".format(epoch))
        for i, data in enumerate(train_dataloader, 0):

            # img_in (the resized input image) is not needed for training.
            img0, img1, img2, img3, img4, mask, img_in = data

            optimizer.zero_grad()

            # Tensors are passed directly; the Variable wrapper is deprecated
            # and a no-op.
            output = net(img0.to(device), img1.to(device), img2.to(device),
                         img3.to(device), img4.to(device))
            # Output and mask are both (batch, 1, H, W), so no reshaping is needed.
            loss_bce = criterion(output, mask.to(device))
            loss_bce.backward()
            optimizer.step()

            # Accumulate a plain float so the running loss is a number
            # rather than a device tensor.
            sum_loss += loss_bce.item()
            if i % 10 == 0:
                print('{} ===================  {}'.format(i, sum_loss/(i + 1)))

        print("Epoch number: {}, Num iter: {}, lr: {}, Current loss: {}".format(epoch, num_iter, optimizer.param_groups[0]['lr'], sum_loss/num_iter))
        # Whole-model checkpoint per epoch; the loading cell expects this format.
        torch.save(net, FILE_NAME + '_{}'.format(epoch + 1))

    torch.save(net, FILE_NAME)
    print("The pre-trained model saved")

In [None]:
# Train only when no previously saved model is going to be loaded.
if use_pre_trained is False:
    # Switch the decoder back to training mode (it was put in eval mode above).
    net.train().to(device)
    train_net()

In [None]:
def calc_rect(img_mask):
    """Return the bounding box of all pixels with value >= 0.5.

    Args:
        img_mask: 2-D array of scores, indexed as ``[row, column]``.

    Returns:
        ``((left, top), (right, bottom))`` as plain Python ints in
        ``(x, y)`` order, or ``(None, None)`` when no pixel reaches 0.5.
    """
    ind = np.argwhere(img_mask >= 0.5)
    if len(ind) == 0:
        return None, None
    # Column 0 of argwhere holds row (y) indices, column 1 column (x) indices.
    rows, cols = ind[:, 0], ind[:, 1]
    # Vectorised min/max, converted to built-in ints so the corners can be
    # handed to drawing functions that expect plain integers.
    top_left = (int(cols.min()), int(rows.min()))
    bottom_right = (int(cols.max()), int(rows.max()))
    return top_left, bottom_right

def show_results(i0, i1, i2, i3, i4, y, x1, im_id):
    """Plot ground truth, thresholded prediction and the boxed input image.

    Args:
        i0, i1, i2, i3, i4: batched feature tensors as produced by the dataset.
        y: batched ground-truth masks (first channel is shown).
        x1: batched input images, channels first.
        im_id: index of the sample inside the batch to display.
    """
    # Inference only: no autograd graph is needed. Tensors are passed
    # directly; the Variable wrapper is deprecated and a no-op.
    with torch.no_grad():
        pred = net(i0.to(device), i1.to(device), i2.to(device),
                   i3.to(device), i4.to(device))
    pr = pred[im_id].cpu().numpy()[0]

    # Channels-last, contiguous copy that OpenCV can draw on.
    xim = np.ascontiguousarray(x1[im_id].permute(1, 2, 0).cpu().detach().numpy())
    xim = cv2.resize(xim, (r_size, r_size))
    top_left, bottom_right = calc_rect(pr)
    if top_left is not None:
        # Full-intensity red in the value range of the image: 255 for 8-bit
        # images, 1.0 for float images (which imshow displays on a 0..1 scale).
        color = (255, 0, 0) if xim.dtype == np.uint8 else (1.0, 0.0, 0.0)
        # Plain ints for the corners, whatever integer type calc_rect returned.
        pt1 = tuple(int(v) for v in top_left)
        pt2 = tuple(int(v) for v in bottom_right)
        cv2.rectangle(xim, pt1, pt2, color, 2)

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 15))
    axes[0].imshow(y[im_id].cpu().detach().numpy()[0])
    # Same threshold (>= 0.5) as calc_rect and calc_accuracy.
    axes[1].imshow(pr >= 0.5)
    axes[2].imshow(xim)

    plt.show()

In [None]:
if use_pre_trained:
    # The checkpoint is a whole pickled model (saved with torch.save(net, ...)),
    # so it must be loaded with weights_only=False; newer PyTorch releases
    # default to weights-only loading and reject such files. map_location lets
    # a checkpoint saved on GPU load on a CPU-only runtime.
    # SECURITY: unpickling executes arbitrary code - only load files you trust.
    # net is rebound only after a successful load, so a failed load does not
    # leave the notebook without a model.
    net = torch.load(FILE_NAME_PR, map_location=device, weights_only=False)
    print("The pre-trained model loaded")

In [None]:
# Evaluation mode for everything that follows.
net.eval().to(device)

In [None]:
# Show sample 2 of the training batch that was fetched earlier in the notebook.
show_results(i0, i1, i2, i3, i4, y, x1, 2)

In [None]:
# Despite its name, test_dataloader1 iterates the TRAINING dataset (shuffled,
# one sample per batch); it is used to measure accuracy on training data.
test_dataloader1 = DataLoader(train_dataset,
                        shuffle=True,
                        num_workers=0,
                        batch_size=1)

# Held-out test samples, in fixed order, one sample per batch.
test_dataset = SegmentDataset(df_test)
test_dataloader2 = DataLoader(test_dataset,
                        shuffle=False,
                        num_workers=0,
                        batch_size=1)

In [None]:
dataiter = iter(test_dataloader1)

In [None]:
# Re-run this cell to step through further training samples one at a time.
i0, i1, i2, i3, i4, y, x1 = next(dataiter)
show_results(i0, i1, i2, i3, i4, y, x1, 0)

In [None]:
def calc_accuracy(test_dataloader, set_id, model, sample_num=None):
    """Compute pooled IoU/Dice and image-level detection counts.

    Args:
        test_dataloader: iterable with ``len()`` yielding
            ``(img0, img1, img2, img3, img4, mask, image)`` batches of size 1.
        set_id: label used in the progress output.
        model: callable taking the five feature tensors and returning a
            ``(1, 1, H, W)`` tensor of probabilities.
        sample_num: evaluate at most this many samples; ``None``, ``0`` or a
            negative value means all of them.

    Returns:
        ``(IoU, Dice, T0, T1, F0, F1)``. IoU and Dice are computed from the
        intersection and union sums pooled over all evaluated samples, and are
        ``nan`` when the corresponding denominator is zero. ``T0``/``T1`` count
        samples where prediction and ground truth are both empty / both
        non-empty, ``F0`` counts missed non-empty masks and ``F1`` counts
        predictions on empty masks.
    """
    total = len(test_dataloader)
    if not sample_num or sample_num <= 0:
        N = total
    else:
        N = min(sample_num, total)

    # Send the inputs to wherever the model's parameters live; fall back to
    # the CPU for models without parameters.
    first_param = next(model.parameters(), None) if hasattr(model, 'parameters') else None
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # Accumulate in Python floats (double precision): pixel counts summed over
    # many images can exceed what float32 represents exactly.
    And = 0.0
    Uni = 0.0
    Uni_dice = 0.0

    T0 = 0
    T1 = 0
    F0 = 0
    F1 = 0

    # Pure inference: do not record an autograd graph for every sample.
    with torch.no_grad():
        for i, data in enumerate(test_dataloader, 0):
            img0, img1, img2, img3, img4, yy, xs = data

            xx1 = model(img0.to(run_device), img1.to(run_device), img2.to(run_device),
                        img3.to(run_device), img4.to(run_device))
            # Binarise the prediction at 0.5; batch size is 1, channel count is 1.
            xx1 = (xx1[0][0].cpu().numpy() >= 0.5).astype(np.float64)
            yy = yy[0][0].cpu().numpy().astype(np.float64)

            owl = float(np.sum(xx1 * yy))
            And += owl
            a_uni_dice = float(np.sum(xx1) + np.sum(yy))
            Uni += a_uni_dice - owl
            Uni_dice += a_uni_dice

            # Image-level presence flags. The ground truth is thresholded like
            # the prediction, so a mask whose maximum is not exactly 1 is
            # still counted in one of the four buckets.
            pred_pos = bool(xx1.any())
            true_pos = bool(yy.max() >= 0.5)
            if not pred_pos and not true_pos:
                T0 += 1
            elif pred_pos and true_pos:
                T1 += 1
            elif not pred_pos and true_pos:
                F0 += 1
            else:
                F1 += 1

            print('{}:  i = {}, And = {}, Uni = {}'.format(set_id, i, And, Uni))

            # Stop after N samples; checking here avoids fetching (and running
            # the feature extraction for) a sample that would not be used.
            if i >= N - 1:
                break

    IoU_av = And / Uni if Uni > 0 else float('nan')
    Dice = 2 * And / Uni_dice if Uni_dice > 0 else float('nan')

    return IoU_av, Dice, T0, T1, F0, F1


In [None]:
# Accuracy on training data, limited to at most 550 samples.
IoU_tr, Dice_tr, T0_tr, T1_tr, F0_tr, F1_tr = calc_accuracy(test_dataloader1, 'train', net, sample_num=550)

In [None]:
print('training set: IoU = {}, Dice = {}, True_0 = {}, True_1 = {}, False_0 = {}, False_1 = {}'.format(IoU_tr, Dice_tr, T0_tr, T1_tr, F0_tr, F1_tr))

In [None]:
# Accuracy on the complete held-out test set.
IoU_ts, Dice_ts, T0_ts, T1_ts, F0_ts, F1_ts = calc_accuracy(test_dataloader2, 'test', net)

In [None]:
print('test set: IoU = {}, Dice = {}, True_0 = {}, True_1 = {}, False_0 = {}, False_1 = {}'.format(IoU_ts, Dice_ts, T0_ts, T1_ts, F0_ts, F1_ts))