In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import cv2
import math
import os
import seaborn as sns
from albumentations import Compose, OneOf, Flip, Rotate, RandomGamma, ElasticTransform, GridDistortion, OpticalDistortion

In [3]:
# Run on the first CUDA device when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Reproducibility: seed every random number generator the pipeline uses
# (PyTorch CPU/CUDA, NumPy, Python's `random`) and make cuDNN deterministic.
import random

random_seed = 42

# PyTorch RNGs. The CUDA call is safe on CPU-only machines.
# For multi-GPU runs use torch.cuda.manual_seed_all(random_seed).
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# Trade cuDNN autotuning speed for repeatable convolution results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Python-side RNGs that CPU-side code (e.g. data augmentation) may draw from.
np.random.seed(random_seed)
random.seed(random_seed)

def seed_worker(worker_id):
    """Seed NumPy and `random` inside a DataLoader worker process.

    Pass as ``worker_init_fn`` to ``DataLoader``. PyTorch already gives each
    worker its own torch seed (a base seed drawn from the main-process RNG for
    every new iterator, plus ``worker_id``). Deriving the NumPy and ``random``
    seeds from ``torch.initial_seed()`` therefore makes them differ per worker
    and per epoch. They stay reproducible from the global ``torch.manual_seed``.

    A fixed ``random_seed + worker_id`` would hand re-created workers identical
    seeds every epoch, replaying the same random augmentations.

    Args:
        worker_id: Worker index. Kept for the ``worker_init_fn`` signature; it
            is already folded into the worker's torch seed.
    """
    # np.random.seed only accepts 32-bit seeds.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)




# Dataset (데이터셋)

In [5]:
# DeepCrack layout: <root_dir>/{train,test}_{img,lab}.
# Set use_deepcrack = False to read the same sub-folders from a generic 'Dataset' root.
use_deepcrack = True
root_dir = 'Dataset/DeepCrack' if use_deepcrack else 'Dataset'

train_image_dir = f'{root_dir}/train_img'
train_mask_dir = f'{root_dir}/train_lab'

test_image_dir = f'{root_dir}/test_img'
test_mask_dir = f'{root_dir}/test_lab'

In [None]:
def _list_files(directory, extension):
    """Return the sorted paths of non-hidden files in `directory` ending with `extension`."""
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(extension) and not fname.startswith(".")
    )


# Images are .jpg and masks are .png. The Generator dataset pairs them by list
# index, so image and mask files must share the same sort order.
train_image_paths = _list_files(train_image_dir, ".jpg")
train_mask_paths = _list_files(train_mask_dir, ".png")
print("Number of train images : ", len(train_image_paths))
print("Number of train masks : ", len(train_mask_paths))

print()

test_image_paths = _list_files(test_image_dir, ".jpg")
test_mask_paths = _list_files(test_mask_dir, ".png")
print("Number of testing images : ", len(test_image_paths))
print("Number of testing masks : ", len(test_mask_paths))

In [None]:
# No validation split is made here. The provided train/test folders are used
# as-is, and the test set doubles as the per-epoch monitoring set in training.
train_image_files, train_mask_files = train_image_paths, train_mask_paths
test_image_files, test_mask_files = test_image_paths, test_mask_paths

print(len(train_image_files), len(train_mask_files))
print(len(test_image_files), len(test_mask_files))

# Target (height, width) for resizing every image and mask.
img_dim = (256, 256)

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
from albumentations import Compose, OneOf, Flip, Rotate,  RandomBrightnessContrast , RandomGamma, ElasticTransform, GridDistortion, OpticalDistortion, RGBShift, CLAHE
import albumentations

class Generator(Dataset):
    """Paired image/mask dataset for binary crack segmentation.

    Each item is read from disk with OpenCV, resized to ``img_dim`` and
    optionally augmented with albumentations.

    Args:
        x_set: Image file paths (read as colour, converted BGR -> RGB).
        y_set: Mask file paths (read as grayscale), paired with ``x_set`` by index.
        augment: If True, apply ``self.augmentations`` to each image/mask pair.
        img_dim: Target (height, width). Defaults to (256, 256), the size the
            rest of the notebook (PCTCNet, ``img_dim``) is built for.

    ``__getitem__`` returns:
        img_x: float tensor (3, H, W), ToTensor's [0, 1] range multiplied by 255.
        img_y: bool tensor (1, H, W), True where the mask is non-zero.

    Raises:
        ValueError: if ``x_set`` and ``y_set`` have different lengths.
        FileNotFoundError: if an image or mask file cannot be read.
    """

    def __init__(self, x_set, y_set, augment=False, img_dim=(256, 256)):
        if len(x_set) != len(y_set):
            raise ValueError(
                f"Got {len(x_set)} images but {len(y_set)} masks; they must pair one-to-one."
            )
        self.x = x_set
        self.y = y_set
        self.img_dim = img_dim
        self.augment = augment

        # Images: default (bilinear) resize, then float CHW in [0, 1].
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.img_dim),
            transforms.ToTensor(),
        ])
        # Masks: nearest-neighbour resize so labels stay binary. A bilinear
        # resize blends crack pixels into their neighbours, and the `> 0`
        # threshold in __getitem__ would then mark every blended pixel as
        # crack, systematically thickening thin cracks.
        self.mask_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.img_dim, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

        self.augmentations = Compose([
            Flip(p=0.7),
            Rotate(limit=90, p=0.7),
            OneOf([
                albumentations.HorizontalFlip(p=1),
                albumentations.RandomRotate90(p=1),
                albumentations.VerticalFlip(p=1)
            ], p=0.7),
            RandomBrightnessContrast(p=0.3),
            RandomGamma(gamma_limit=(80, 120), p=0.5)
        ])

    def __len__(self):
        return len(self.x)

    @staticmethod
    def _read(path, flag):
        """Read an image with OpenCV, raising instead of silently returning None."""
        img = cv2.imread(path, flag)
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return img

    def __getitem__(self, idx):
        img_x = cv2.cvtColor(self._read(self.x[idx], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        img_y = self._read(self.y[idx], cv2.IMREAD_GRAYSCALE)

        img_x = self.transform(img_x)
        img_y = self.mask_transform(img_y)

        if self.augment:
            # albumentations works on HWC numpy arrays. Spatial transforms are
            # applied identically to image and mask; convert back to CHW after.
            augmented = self.augmentations(
                image=img_x.permute(1, 2, 0).numpy(),
                mask=img_y.permute(1, 2, 0).numpy(),
            )
            img_x = torch.from_numpy(augmented["image"]).permute(2, 0, 1)
            img_y = torch.from_numpy(augmented["mask"]).permute(2, 0, 1)

        img_y = img_y > 0
        return img_x * 255, img_y

In [None]:
# Define batch size and number of DataLoader worker processes
batch_size = 3
num_workers = 0  # 0 = load in the main process (seed_worker is only called when > 0)

# Training data is augmented and reshuffled every epoch; test data is neither.
train_dataset = Generator(train_image_files, train_mask_files, augment=True)
test_dataset = Generator(test_image_files, test_mask_files, augment=False)

# Forward num_workers explicitly so the setting above actually takes effect.
train_data_loader = DataLoader(train_dataset, batch_size, shuffle=True,
                               num_workers=num_workers, worker_init_fn=seed_worker)
test_data_loader = DataLoader(test_dataset, batch_size, shuffle=False,
                              num_workers=num_workers, worker_init_fn=seed_worker)

In [None]:
def _show_row(samples, title, **imshow_kwargs):
    """Plot up to five HxW or HxWxC samples side by side; unused panels stay hidden."""
    fig, axes = plt.subplots(1, 5, figsize=(13, 2.5))
    fig.suptitle(title, fontsize=15)
    axes = axes.flatten()
    for ax in axes:
        ax.axis('off')
    for sample, ax in zip(samples, axes):
        ax.imshow(sample, **imshow_kwargs)
    plt.tight_layout()
    plt.show()


# Take one (augmented) training batch as a visual sanity check.
images, masks = next(iter(train_data_loader))
print(images.shape)

# NCHW -> NHWC for matplotlib.
images = images.permute(0, 2, 3, 1)
masks = masks.permute(0, 2, 3, 1)
print(images.shape)
print(masks.shape)

# The dataset scales images to [0, 255]; imshow expects floats in [0, 1].
_show_row((images[:5] / 255).numpy(), 'Original Images')
_show_row(masks[:5, ..., 0].numpy(), 'Original Masks', cmap='gray')

# Model (모델)

In [None]:
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

class DistilledVisionTransformer(VisionTransformer):
    """DeiT-style Vision Transformer with a distillation token and a small conv stem.

    Extends timm's ``VisionTransformer`` with:
      * ``conv``: Conv2d(3->3, 7x7, stride 2, padding 4) + BatchNorm2d + MaxPool2d(3, stride 2),
        applied to the input before ``patch_embed`` in ``forward_features``;
      * ``dist_token``: a learned token placed right after the class token;
      * ``pos_embed``: re-created with ``num_patches + 2`` positions (class + distillation);
      * ``head_dist``: a second classifier head fed by the distillation token.

    NOTE(review): the stem shrinks the input before patch embedding. A
    256x256 input becomes (256 + 2*4 - 7)//2 + 1 = 129 after the conv, then
    (129 - 3)//2 + 1 = 64 after the pool, i.e. 64x64. ``pos_embed`` is sized
    from ``patch_embed.num_patches`` for the configured ``img_size``, so unless
    ``img_size`` matches the post-stem size, ``forward_features`` will
    presumably fail (timm's input-size check or the ``pos_embed`` addition).
    PCTCNet in this notebook only uses ``patch_embed`` and ``blocks``, so this
    path is not exercised there. Confirm intent before calling ``forward``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.conv = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size=7, stride=2, padding=4),
            nn.BatchNorm2d(3),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        # Two extra positions: one for the class token, one for the distillation token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        # NOTE(review): relies on the timm base class providing `_init_weights`;
        # this helper's name/signature differs across timm versions.
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        """Return the final-normalised class-token and distillation-token embeddings.

        Pipeline: conv stem -> patch_embed -> prepend [cls, dist] tokens ->
        add pos_embed -> self.pos_drop -> transformer blocks -> norm.
        """
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.conv(x)
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        # Token order along dim 1: [cls, dist, patch tokens...]
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        # Index 0 = class token, index 1 = distillation token.
        return x[:, 0], x[:, 1]

    def forward(self, x):
        """Classify with both heads.

        Returns ``(cls_logits, dist_logits)`` in training mode and their
        element-wise mean in eval mode.
        """
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2


@register_model
def deit_tiny_patch16_256(pretrained=True, **kwargs):
    """Build a DeiT-Tiny-sized ViT (embed_dim 192, depth 12, 3 heads, 16x16 patches) for 256x256 inputs.

    Args:
        pretrained: If True, load weights from ``pretrained/deit_tiny_patch16_256.pth``,
            which must contain a ``'model'`` state dict. Loading uses
            ``strict=False``, so missing/unexpected keys are ignored, but shape
            mismatches still raise.
        **kwargs: Forwarded to timm's ``VisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to timm's ``_cfg()``.
    """
    model = VisionTransformer(
        img_size=256, patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    if pretrained:
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
        # machine. load_state_dict copies values into the model's own parameters;
        # move the model with .to(device) afterwards.
        ckpt = torch.load('pretrained/deit_tiny_patch16_256.pth', map_location='cpu')
        model.load_state_dict(ckpt['model'], strict=False)
    model.default_cfg = _cfg()
    return model


def deit_tiny_distilled_patch16_256(pretrained=True, **kwargs):
    """Build a DistilledVisionTransformer (embed_dim 256, depth 12, 8 heads, 16x16 patches, img_size 256).

    Despite the "tiny" name, the width is 256 (not DeiT-Tiny's 192). That
    matches the 256-channel token maps PCTCNet reshapes and fuses.

    Args:
        pretrained: If True, load weights from
            ``pretrained/deit_tiny_distilled_patch16_256.pth`` (expects a
            ``'model'`` key). ``strict=False`` ignores missing/unexpected keys
            but not shape mismatches. NOTE(review): a stock 192-dim DeiT-Tiny
            checkpoint would not fit these 256-dim weights.
        **kwargs: Forwarded to ``DistilledVisionTransformer``.

    Returns:
        The model, with ``default_cfg`` set to timm's ``_cfg()``.
    """
    model = DistilledVisionTransformer(
        img_size=256, patch_size=16, embed_dim=256, depth=12, num_heads=8, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    if pretrained:
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        ckpt = torch.load('pretrained/deit_tiny_distilled_patch16_256.pth', map_location='cpu')
        model.load_state_dict(ckpt['model'], strict=False)
    model.default_cfg = _cfg()
    return model


## Specifically, the images in the ImageNet 2012 dataset (Russakovsky et al., 2015) were resized to 256 × 256 and divided into 256 patches of 16 × 16 pixels each.

In [12]:
from torchvision import models

class resnet34(torch.nn.Module):
    """ResNet-34 encoder returning the stem output and all four stage outputs.

    The torchvision stem's max-pool is deliberately not applied, so every map
    has twice the resolution of the stock backbone. For an H x W input:

        x0        (64 ch)   H/2  x W/2    conv1 (stride 2) + bn1 + relu
        feature1  (64 ch)   H/2  x W/2    layer1
        feature2  (128 ch)  H/4  x W/4    layer2
        feature3  (256 ch)  H/8  x W/8    layer3
        feature4  (512 ch)  H/16 x W/16   layer4

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet34``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Keep the whole torchvision model registered as `features`, and re-expose
        # the pieces used in forward() under their own names. These attribute
        # names (and their order) determine the state_dict keys.
        self.features = models.resnet34(pretrained=pretrained)
        backbone = self.features
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        stem = self.relu(self.bn1(self.conv1(input)))
        outputs = [stem]
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            outputs.append(stage(outputs[-1]))
        return tuple(outputs)

class PyramidPoolingModule(nn.Module):
    """Parameter-free pyramid pooling that adds multi-scale context to a feature map.

    For each bin size ``s``, the input is average-pooled to ``s x s``, upsampled
    back to the input's spatial size (bilinear, ``align_corners=True``) and
    added onto the running sum. Shape is unchanged.

    Args:
        pyramids: Output sizes for ``F.adaptive_avg_pool2d``; defaults to (1, 2, 3, 6).
            Copied into a per-instance list, so instances never share (or
            mutate) the default.
    """

    def __init__(self, pyramids=(1, 2, 3, 6)):
        super(PyramidPoolingModule, self).__init__()
        self.pyramids = list(pyramids)

    def forward(self, input):
        height, width = input.shape[2:]
        feat = input
        for bin_size in self.pyramids:
            pooled = F.adaptive_avg_pool2d(input, output_size=bin_size)
            feat = feat + F.interpolate(pooled, size=(height, width), mode='bilinear', align_corners=True)
        return feat

class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools each channel, passes the (B, C) vector through a
    bottleneck MLP (C -> hidden -> C) ending in a sigmoid, and rescales the
    input channels by the resulting gates in (0, 1). Output shape equals input shape.

    Args:
        channel: Number of channels C.
        r: Reduction ratio; ``hidden = max(1, channel // r)``. The clamp stops
            ``channel < r`` from building a zero-width layer, which would make
            every gate a constant 0.5.
    """

    def __init__(self, channel, r=16):
        super(SEBlock, self).__init__()
        hidden = max(1, channel // r)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: (B, C, H, W) -> (B, C)
        y = self.avg_pool(x).view(b, c)
        # Excitation: per-channel gates, reshaped to broadcast over H and W
        y = self.fc(y).view(b, c, 1, 1)
        # Rescale the input channels
        return x * y

class ResBlock(nn.Module):
    """Two 3x3 conv-BN layers plus a 1x1 conv-BN projection shortcut, then ReLU.

    Spatial size is preserved (stride 1, padding 1); channels change from
    ``in_channel`` to ``out_channel``. The shortcut always projects, even when
    the channel counts match.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.left = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        main_path = self.left(x)
        skip_path = self.shortcut(x)
        return F.relu(main_path + skip_path)

class FeatureFusion(nn.Module):
    """CFF unit: resize, concatenate, then fuse with a ResBlock.

    ``x_low`` is bilinearly resized (``align_corners=True``) to ``x_high``'s
    spatial size and concatenated in front of it along channels. Therefore
    ``in_channel`` must equal the channels of ``x_low`` plus those of ``x_high``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.fusion = ResBlock(in_channel, out_channel)

    def forward(self, x_high, x_low):
        target_size = x_high.size()[2:]
        x_low_resized = F.interpolate(x_low, size=target_size, mode='bilinear', align_corners=True)
        return self.fusion(torch.cat((x_low_resized, x_high), dim=1))

class RPMBlock(nn.Module):
    """Two parallel conv branches (3x3 and 1x1), each followed by ReLU, summed.

    Channel count and spatial size are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=1)
        self.relu3 = nn.ReLU(inplace=True)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        branch3 = self.relu3(self.conv3(x))
        branch1 = self.relu1(self.conv1(x))
        return branch3 + branch1

class DecoderBottleneckLayer(nn.Module):
    """Bottleneck decoder stage that doubles the spatial resolution.

    The stage runs three steps:
      1. 1x1 conv to ``in_channels // 4`` -> BN -> ReLU.
      2. 2x upsample: a 3x3 stride-2 transposed conv + BN + ReLU, or
         parameter-free bilinear upsampling when ``use_transpose`` is False.
      3. 1x1 conv to ``n_filters`` -> BN -> ReLU.
    """

    def __init__(self, in_channels, n_filters, use_transpose=True):
        super().__init__()
        mid = in_channels // 4

        self.conv1 = nn.Conv2d(in_channels, mid, 1)
        self.norm1 = nn.BatchNorm2d(mid)
        self.relu1 = nn.ReLU(inplace=True)

        if use_transpose:
            # kernel 3, stride 2, padding 1, output_padding 1 -> exactly 2x H and W
            self.up = nn.Sequential(
                nn.ConvTranspose2d(mid, mid, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(mid),
                nn.ReLU(inplace=True),
            )
        else:
            self.up = nn.Upsample(scale_factor=2, align_corners=True, mode="bilinear")

        self.conv3 = nn.Conv2d(mid, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x):
        for layer in (self.conv1, self.norm1, self.relu1, self.up,
                      self.conv3, self.norm3, self.relu3):
            x = layer(x)
        return x

In [13]:
class PCTCNet(nn.Module):
    """Hybrid CNN + Transformer network for binary crack segmentation.

    Two encoders see the same input:
      * CNN branch: a ResNet-34 encoder (``resnet34``, no pretrained weights)
        producing ``e0..e4``; ``e4`` goes through a pyramid pooling module.
      * Transformer branch: a per-channel 1x1 conv (``prev_cnn``, ``groups=3``),
        then ``patch_embed`` and the 12 blocks of
        ``deit_tiny_distilled_patch16_256`` (no pretrained weights). Token maps
        after blocks 3, 6, 9 and 12 are reshaped to (B, 256, 16, 16).

    Token maps are fused into the CNN features (``FeatureFusion`` + SE). The
    skip features are refined by stacks of ``RPMBlock``s, and a bottleneck
    decoder upsamples back to the input resolution. The output is raw logits
    (no sigmoid).

    Args:
        n_channels: only stored as ``self.inchannels``; both branches are
            hard-wired to 3 input channels.
        n_classes: number of output (logit) channels.

    NOTE(review): the hard-coded ``view(b, 256, 16, 16)`` calls require 256
    tokens of width 256, i.e. a 256x256 input with 16x16 patches.
    """

    def __init__(self, n_channels=3, n_classes=1):
        super(PCTCNet, self).__init__()
        self.n_class = n_classes
        self.inchannels = n_channels
        # Transformer width (embed_dim=256) = channel count of the reshaped token maps.
        size = 256

        self.cnn = resnet34(pretrained=False)
        self.headpool = PyramidPoolingModule()

        transformer = deit_tiny_distilled_patch16_256(pretrained=False)
        # Only patch_embed and the blocks are registered on this module. The
        # cls/dist tokens, pos_embed and heads of `transformer` are not used,
        # so no positional embedding is added to the tokens.
        self.patch_embed = transformer.patch_embed
        self.transformers = nn.ModuleList(
            [transformer.blocks[i] for i in range(12)]
        )
        
        
        # Channel attention on the fused features at each scale.
        self.se = SEBlock(channel=512)
        self.se1 = SEBlock(channel=64)
        self.se2 = SEBlock(channel=128)
        self.se3 = SEBlock(channel=256)

        # Fuse CNN features (x_high) with 256-channel token maps (x_low).
        self.fusion = FeatureFusion(in_channel=512 + size, out_channel=512)
        self.fusion1 = FeatureFusion(in_channel=64 + size, out_channel=64)
        self.fusion2 = FeatureFusion(in_channel=128 + size, out_channel=128)
        self.fusion3 = FeatureFusion(in_channel=256 + size, out_channel=256)

        self.RPMBlock1 = RPMBlock(channels=64)
        self.RPMBlock2 = RPMBlock(channels=128)
        self.RPMBlock3 = RPMBlock(channels=256)
        # NOTE(review): each ModuleList repeats the *same* RPMBlock instance, so
        # all 6 / 4 / 2 applications share one set of weights. Confirm this is
        # intended rather than separate blocks per step.
        self.FAM1 = nn.ModuleList([self.RPMBlock1 for i in range(6)])
        self.FAM2 = nn.ModuleList([self.RPMBlock2 for i in range(4)])
        self.FAM3 = nn.ModuleList([self.RPMBlock3 for i in range(2)])

        filters = [64, 128, 256, 512]
        self.decoder4 = DecoderBottleneckLayer(filters[3], filters[2])
        self.decoder3 = DecoderBottleneckLayer(filters[2], filters[1])
        self.decoder2 = DecoderBottleneckLayer(filters[1], filters[0])
        # NOTE(review): decoder1 is constructed but never called in forward().
        self.decoder1 = DecoderBottleneckLayer(filters[0], filters[0])

        # Output head: 2x transposed-conv upsample (k4, s2, p1), then two 3x3
        # convs down to n_classes logit channels.
        self.final_conv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
        self.final_relu1 = nn.ReLU(inplace=True)
        self.final_conv2 = nn.Conv2d(32, 32, 3, padding=1)
        self.final_relu2 = nn.ReLU(inplace=True)
        self.final_conv3 = nn.Conv2d(32, n_classes, 3, padding=1)
        
        # self.prev_cnn = nn.Conv2d(3, 3, kernel_size=1, padding=0)
        # Per-channel (depthwise, groups=3) 1x1 conv applied before patch embedding.
        self.prev_cnn = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0, groups=3)

    def forward(self, x):
        """Return segmentation logits of shape (B, n_classes, H, W) for a 256x256 input."""
        # Only b is used; spatial sizes below are hard-coded for 256x256 inputs.
        b, c, h, w = x.shape
        # e0 is computed but unused.
        e0, e1, e2, e3, e4 = self.cnn(x)
        feature_cnn = self.headpool(e4)
        size = 256

        prev_cnn = self.prev_cnn(x)


        # Transformer branch: 12 blocks; tap token maps after blocks 3, 6, 9 and 12.
        # emb is presumably (B, N_tokens, C) from timm's PatchEmbed; each tap is
        # permuted to (B, C, N_tokens) and viewed as a 16x16 grid.
        emb = self.patch_embed(prev_cnn)
        emb = self.transformers[0](emb)
        emb = self.transformers[1](emb)
        emb1 = self.transformers[2](emb)
        feature_tf1 = emb1.permute(0, 2, 1)
        feature_tf1 = feature_tf1.view(b, size, 16, 16)

        emb1 = self.transformers[3](emb1)
        emb1 = self.transformers[4](emb1)
        emb2 = self.transformers[5](emb1)
        feature_tf2 = emb2.permute(0, 2, 1)
        feature_tf2 = feature_tf2.view(b, size, 16, 16)

        emb2 = self.transformers[6](emb2)
        emb2 = self.transformers[7](emb2)
        emb3 = self.transformers[8](emb2)
        feature_tf3 = emb3.permute(0, 2, 1)
        feature_tf3 = feature_tf3.view(b, size, 16, 16)

        emb3 = self.transformers[9](emb3)
        emb3 = self.transformers[10](emb3)
        emb4 = self.transformers[11](emb3)

        feature_tf = emb4.permute(0, 2, 1)
        feature_tf = feature_tf.view(b, size, 16, 16)

        # Deepest level: fuse pooled CNN features with the final token map.
        feature_cat = self.fusion(feature_cnn, feature_tf)
        feature_cat = self.se(feature_cat)

        # Skip levels: fuse each CNN stage with an intermediate token map
        # (resized to that stage's resolution), apply SE, then RPM refinement.
        e1 = self.fusion1(e1, feature_tf1)
        e2 = self.fusion2(e2, feature_tf2)
        e3 = self.fusion3(e3, feature_tf3)
        e1 = self.se1(e1)
        e2 = self.se2(e2)
        e3 = self.se3(e3)
        for i in range(2):
            e3 = self.FAM3[i](e3)
        for i in range(4):
            e2 = self.FAM2[i](e2)
        for i in range(6):
            e1 = self.FAM1[i](e1)

        # Decoder with additive skips (for 256x256 input: 16 -> 32 -> 64 -> 128).
        d4 = self.decoder4(feature_cat) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        # 128 -> 256, then project to n_classes logits (no sigmoid here).
        out1 = self.final_conv1(d2)
        out1 = self.final_relu1(out1)
        out = self.final_conv2(out1)
        out = self.final_relu2(out)
        out = self.final_conv3(out)
        return out

# Loss functions

In [15]:
class DiceLoss(nn.Module):
    """Soft Dice loss on logits, computed over the whole batch at once.

    ``loss = 1 - (2 * sum(P * T) + smooth) / (sum(P) + sum(T) + smooth)`` with
    ``P = sigmoid(inputs)``.

    Args:
        weight, size_average: Accepted for API compatibility; unused.

    ``forward(inputs, targets, smooth=1)``:
        inputs: raw logits of any shape (contiguous or not).
        targets: tensor with the same number of elements, expected to be a
            0/1 mask (the training loop passes ``y.float()``).
        smooth: additive smoothing that keeps the ratio defined for empty masks.
        Returns a scalar tensor.
    """

    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Expects logits; drop this line if the model already outputs probabilities.
        probs = torch.sigmoid(inputs)

        # reshape (not view) so non-contiguous tensors, e.g. after permute, also work.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice
class DiceFocalLoss(nn.Module):
    """``0.5 * soft Dice + 0.5 * sigmoid focal loss``, both computed from logits.

    Focal term, per element:
    ``(1 - p_t) ** gamma * BCE_with_logits``. ``p_t`` is the predicted
    probability of the true class. The result is weighted by ``alpha`` for
    negatives (target 0) and ``1 - alpha`` for positives (target 1), then
    summed over all elements.

    NOTE(review): with the default ``alpha=0.25``, positives get weight 0.75.
    That is the reverse of the RetinaNet convention, where alpha weights the
    positive class; confirm this is intended. The focal term is a *sum*, so
    its magnitude grows with batch and image size, while the Dice term stays
    in [0, 1].

    Args:
        gamma: focusing exponent.
        alpha: class weighting, see above.
        weight, size_average: Accepted for API compatibility; unused.

    Raises:
        ValueError: in ``forward`` if ``inputs`` and ``targets`` differ in size.
    """

    def __init__(self, gamma=2., alpha=0.25, weight=None, size_average=True):
        super(DiceFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets, smooth=1):
        if not (targets.size() == inputs.size()):
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(targets.size(), inputs.size()))

        # --- Focal loss ---
        # Numerically stable element-wise BCE on logits.
        max_val = (-inputs).clamp(min=0)
        loss = inputs - inputs * targets + max_val + ((-max_val).exp() + (-inputs - max_val).exp()).log()

        # (1 - p_t) ** gamma in log space: logsigmoid(-x) = log(1 - p) for
        # positives, logsigmoid(x) = log(p) for negatives.
        invprobs = F.logsigmoid(-inputs * (targets * 2.0 - 1.0))
        loss = (invprobs * self.gamma).exp() * loss

        focal_loss = self.alpha * (1 - targets) * loss + (1 - self.alpha) * targets * loss
        focal_loss = focal_loss.sum()

        # --- Dice loss ---
        # reshape (not view) so non-contiguous inputs also work.
        probs = torch.sigmoid(inputs).reshape(-1)
        flat_targets = targets.reshape(-1)

        intersection = (probs * flat_targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (probs.sum() + flat_targets.sum() + smooth)

        return 0.5 * dice_loss + 0.5 * focal_loss
        
class DiceBCELoss(nn.Module):
    """Weighted sum of BCE-with-logits (weight 0.25) and soft Dice loss (weight 0.75).

    Both terms take raw logits. ``targets`` must be a float tensor with the same
    shape as ``outputs``, as ``nn.BCEWithLogitsLoss`` requires.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, outputs, targets):
        bce_term = self.bce(outputs, targets)
        dice_term = self.dice(outputs, targets)
        return 0.25 * bce_term + 0.75 * dice_term

# Training (학습)

In [16]:
max_test_sample = 5  # test batches visualised by test() (first item of each) on every 10th epoch

In [17]:
def train(dataloader, model, loss_fn, optimizer, scheduler):
    """Run one training epoch, then advance the learning-rate scheduler once.

    Args:
        dataloader: yields ``(X, y)`` batches; both are cast to float and moved to ``device``.
        model: network whose output is passed straight to ``loss_fn``.
        loss_fn: ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: epoch-level scheduler; ``step()`` is called once after the epoch.

    Prints the current batch loss every 50 batches.
    """
    model.train()
    n_samples = len(dataloader.dataset)
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.float().to(device)
        targets = targets.float().to(device)

        batch_loss = loss_fn(model(inputs), targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            seen = (batch_idx + 1) * len(inputs)
            print(f"loss: {batch_loss.item():>7f}  [{seen:>5d}/{n_samples:>5d}]")

    scheduler.step()

In [18]:
import torchmetrics
import json
# Binary segmentation metrics on `device`. test() feeds them probability maps
# and boolean masks.
accuracy, precision, recall, f1, Auroc, IoU = (
    metric_cls(task='binary').to(device)
    for metric_cls in (
        torchmetrics.Accuracy,
        torchmetrics.Precision,
        torchmetrics.Recall,
        torchmetrics.F1Score,
        torchmetrics.AUROC,  # Auroc is not used by test()
        torchmetrics.JaccardIndex,
    )
)
Dice = torchmetrics.Dice().to(device)

# Per-epoch metric history appended to by test(). The single zero tuple is only
# a baseline for the epoch-0 "change vs previous" printout; test() clears the
# list at epoch 0.
previous_metrics = [(0, 0, 0, 0)]
#old_model_metrics = json.load(open('../../DTrC-Net/history.json', 'r'))

def _plot_test_sample(fig, ax, col, X, y, pred):
    """Draw the first item of a test batch into column `col` of test()'s 4-row grid.

    Rows: input channel 0, ground-truth mask, min-max normalised probability
    heatmap, and the probability map thresholded at 0.5.
    """
    prob_np = pred[0, 0].cpu().numpy()
    # Normalise for display only; guard against a constant map (0 / 0 -> NaN).
    span = prob_np.max() - prob_np.min()
    heatmap_np = (prob_np - prob_np.min()) / span if span > 0 else np.zeros_like(prob_np)

    fig.colorbar(ax[0, col].imshow(X[0, 0].cpu().numpy()), ax=ax[0, col])
    fig.colorbar(ax[1, col].imshow(y[0, 0].cpu().numpy()), ax=ax[1, col])
    fig.colorbar(ax[2, col].imshow(heatmap_np), ax=ax[2, col])
    # Threshold the probabilities themselves, matching how the metrics binarise,
    # rather than the per-image normalised heatmap.
    fig.colorbar(ax[3, col].imshow(prob_np > 0.5), ax=ax[3, col])
    for row in range(4):
        ax[row, col].axis('off')


def _plot_metric_history(history):
    """Plot per-epoch accuracy/loss/IoU/Dice, a precision-vs-recall scatter, and the loss curve."""
    epochs = range(1, len(history) + 1)
    ticks = range(1, len(history) + 1)  # include the last epoch

    plt.plot(epochs, [m[0] for m in history], label='Accuracy')
    plt.plot(epochs, [m[1] for m in history], label='Loss')
    plt.plot(epochs, [m[2] for m in history], label='IoU')
    plt.plot(epochs, [m[3] for m in history], label='Dice')
    plt.xlabel('Epochs')
    plt.ylabel('Metrics')
    plt.xlim([1, len(epochs)])
    plt.xticks(ticks)
    plt.title('Metrics Progression')
    plt.legend()
    plt.show()

    plt.scatter([m[5] for m in history], [m[4] for m in history])
    plt.title('Precision-Recall Curve')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.show()

    plt.plot(epochs, [m[1] for m in history], label='Current model')
    #plt.plot(epochs, [m[1] for m in old_model_metrics][:len(history)], label='Old model')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.xlim([1, len(epochs)])
    plt.xticks(ticks)
    plt.title('Loss comparison')
    plt.legend()
    plt.show()


def test(dataloader, model, loss_fn, current_epoch):
    """Evaluate `model` on `dataloader`, print metrics and append them to `previous_metrics`.

    Every 10th epoch (``current_epoch % 10 == 0``), it also plots the first
    item of up to ``max_test_sample`` batches and the metric history.

    Args:
        dataloader: yields ``(X, y)`` with ``y`` a boolean mask batch.
        model: network returning logits with the same shape as ``y``.
        loss_fn: loss that takes *logits*, as in ``train`` (e.g. ``DiceBCELoss``).
        current_epoch: 0-based epoch index; epoch 0 resets ``previous_metrics``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    show_figures = current_epoch % 10 == 0
    if show_figures:
        # squeeze=False keeps `ax` 2-D even when max_test_sample == 1.
        fig, ax = plt.subplots(4, max_test_sample, figsize=(40, 12), squeeze=False)
        ax[3, max_test_sample // 2].set_title("> 0.5")

    logit_batches, target_batches = [], []
    with torch.no_grad():
        for batch_idx, (X, y) in enumerate(dataloader):
            X, y = X.float().to(device), y.to(device)
            logits = model(X).to(device)
            logit_batches.append(logits)
            target_batches.append(y)
            if show_figures and batch_idx < max_test_sample:
                _plot_test_sample(fig, ax, batch_idx, X, y, torch.sigmoid(logits))

    if show_figures:
        plt.show()

    if not logit_batches:
        raise ValueError("test dataloader yielded no batches")

    # Concatenate once instead of re-allocating a growing tensor every batch.
    total_logits = torch.cat(logit_batches, dim=0)
    total_pred = torch.sigmoid(total_logits)
    total_y = torch.cat(target_batches, dim=0)

    score_accuracy = accuracy(total_pred, total_y).item()
    # The loss functions apply the sigmoid themselves, so they get logits, just
    # as in train(). Feeding probabilities would apply the sigmoid twice.
    score_loss = loss_fn(total_logits, total_y.float()).item()
    score_precision = precision(total_pred, total_y).item()
    score_recall = recall(total_pred, total_y).item()
    score_f1 = f1(total_pred, total_y).item()
    score_IoU = IoU(total_pred, total_y).item()
    score_Dice = Dice(total_pred, total_y).item()

    print(f"Accuracy: {(100*score_accuracy):>8f}%({100*(score_accuracy-previous_metrics[-1][0]):+g}%p)")
    print(f"Loss: {score_loss:>8f}({score_loss - previous_metrics[-1][1]:+g})")
    print(f"Precision: {score_precision:>8f}")
    print(f"Recall: {score_recall:>8f}")
    print(f"F1: {score_f1:>8f}")
    print(f"IoU: {score_IoU:>8f}")
    print(f"Dice: {score_Dice:>8f}")

    if current_epoch == 0:
        previous_metrics.clear()
    previous_metrics.append((score_accuracy, score_loss, score_IoU, score_Dice, score_precision, score_recall, score_f1))
    # "Global" values are the mean of the per-epoch scores recorded so far.
    print(f"Global precision: {sum([m[4] for m in previous_metrics])/len(previous_metrics):>8f}")
    print(f"Global recall: {sum([m[5] for m in previous_metrics])/len(previous_metrics):>8f}")
    print(f"Global F1: {sum([m[6] for m in previous_metrics])/len(previous_metrics):>8f}")

    if show_figures:
        _plot_metric_history(previous_metrics)

    print("----------------------------------------------------------------------------------")
    #json.dump(previous_metrics, open(f"history.json", "w"))

In [None]:
import torch
from torchsummaryX import summary

# Instantiate the model on `device` (GPU when available) and print a layer-by-layer
# summary for one 3x256x256 input. The next cell builds a fresh model for training.
model = PCTCNet().to(device)
summary(model, torch.zeros((1, 3, 256, 256), device=device))


In [None]:
import time
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

# Fresh model, loss and optimiser for the training run.
model = PCTCNet().to(device)
loss_fn = DiceBCELoss().to(device)

start_time = time.time()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

# NOTE: StepLR is used below, so only the *number* of milestones matters. The
# LR is multiplied by `gamma` every epochs // len(milestones) = 100 epochs; the
# milestone values (20, 50, 100) themselves are unused. MultiStepLR(optimizer,
# milestones, gamma) would decay at those epochs instead.
milestones = [20, 50, 100]
epochs = 300
gamma = 0.1
step_size = epochs // len(milestones)

# This rebinding shadows the `lr_scheduler` module imported above; the training
# loop passes this scheduler instance to train() under that name.
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

In [None]:

# Main loop: one training pass, then evaluation on the test split, every epoch.
# `lr_scheduler` is the StepLR instance created in the previous cell.
for current_epoch in range(epochs):
    print(f"Epoch {current_epoch+1}\n-------------------------------")
    train(train_data_loader, model, loss_fn, optimizer, lr_scheduler)
    test(test_data_loader, model, loss_fn, current_epoch)

print("Done!")