In [1]:
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from mydataset import MyDataset
from myunet import Unet
from SiamUnet_diff import SiamUnet_diff
from Siamunet_conc import SiamUnet_conc
from tqdm import tqdm
from torch.cuda.amp import autocast
from data_augmentation import RandomRotation, RandomVerticalFlip, RandomHorizontalFlip, Compose, ToTensor

In [2]:
batch_size = 100
lr = 0.001
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
train_set = MyDataset(root="data", is_train=True, transform=Compose([
    ToTensor(),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomRotation([0, 90, 180, 270]),
]), normalize=transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))

val_set = MyDataset(root="data", is_train=False, transform=Compose([ToTensor()]),
                    normalize=transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)))

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)

val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)
len(train_set), len(val_set), len(train_loader), len(val_loader)

(3800, 200, 38, 2)

In [4]:
model0 = Unet(6, 2)
model0.cuda()
criterion0 = nn.CrossEntropyLoss()
optim0 = torch.optim.AdamW(model0.parameters(), lr=0.0001, weight_decay=1e-9)
scheduler0 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim0, mode='max', factor=0.1, patience=2,
    verbose=True)
scaler0 = torch.cuda.amp.GradScaler(enabled=True)

In [5]:
model1 = SiamUnet_diff(3, 2)
model1.cuda()

criterion1 = nn.CrossEntropyLoss()
optim1 = torch.optim.AdamW(model1.parameters(), lr=0.0001, weight_decay=1e-9)
scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim1, mode='max', factor=0.1, patience=2,
    verbose=True)
scaler1 = torch.cuda.amp.GradScaler(enabled=True)

In [6]:
model2 = SiamUnet_conc(3, 2)
model2.cuda()

optim2 = torch.optim.AdamW(model2.parameters(), lr=0.0001, weight_decay=1e-9)
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim2, mode='max', factor=0.1, patience=2,
    verbose=True)
scaler2 = torch.cuda.amp.GradScaler(enabled=True)

In [7]:
models = [model0, model1, model2]
criterion = nn.CrossEntropyLoss()
schedulers = [scheduler0, scheduler1, scheduler2]
scalers = [scaler0, scaler1, scaler2]
optims = [optim0, optim1, optim2]

In [8]:
def train_model(epoch):
    for model in models:
        model.train()
    print(f"Epoch {epoch} Training")
    with tqdm(train_loader, desc=str(epoch)) as it:
        for idx, (img1, img2, mask) in enumerate(it, 0):
            img1, img2, mask = img1.cuda(), img2.cuda(), mask.cuda()
            mask = mask.long()
            mask = mask.squeeze(1)
            vote = []
            for model, optim, scaler in zip(models, optims, scalers):
                optim.zero_grad()
                with autocast():
                    outputs = model(img1, img2)
                    loss = criterion(outputs, mask)
                scaler.scale(loss).backward()
                scaler.step(optim)
                scaler.update()
                _, pred = torch.max(outputs.data, 1)
                vote.append(pred)
            vote = torch.stack(vote, dim=0)
            pred = torch.div(torch.sum(vote, dim=0), 2, rounding_mode="trunc")
            p, r, f1, iou = get_index(pred, mask)
            it.set_postfix_str(f"loss: {loss.item(): .4f} p: {p: .4f}  r: {r: .4f}  f1: {f1: .4f}  iou: {iou: .4f}")

In [9]:
def get_index(pred, label):
    eps = 1e-7
    tp = torch.sum(label * pred)
    fp = torch.sum(pred) - tp
    fn = torch.sum(label) - tp

    p = (tp + eps) / (tp + fp + eps)
    r = (tp + eps) / (tp + fn + eps)
    f1 = (2 * p * r + eps) / (p + r + eps)
    iou = (tp + eps) / (tp + fn + fp + eps)
    return p, r, f1, iou


def test_model(epoch):
    for model in models:
        model.eval()
    global max_score
    f1s = 0
    print(f"Epoch {epoch} Testing")
    with torch.no_grad():
        with tqdm(val_loader, desc=str(epoch)) as it:
            for img1, img2, mask in it:
                img1, img2, mask = img1.cuda(), img2.cuda(), mask.cuda()
                mask = mask.squeeze(1)
                vote = []
                for model in models:
                    outputs = model(img1, img2)
                    _, pred = torch.max(outputs.data, 1)
                    vote.append(pred)
                vote = torch.stack(vote, dim=0)
                pred = torch.div(torch.sum(vote, dim=0), 2, rounding_mode="trunc")
                p, r, f1, iou = get_index(pred, mask)
                f1s += f1
                it.set_postfix_str(f"p: {p: .4f}  r: {r: .4f}  f1: {f1: .4f}  iou: {iou: .4f}")
    f1s /= len(val_loader)
    for scheduler in schedulers:
        scheduler.step(f1s)
    print("f1", f1s.item())
    if max_score < f1s:
        max_score = f1s
        print('max_score', max_score.item())

In [10]:
num_epoch = 10
max_score = 0
for epoch in range(0, num_epoch):
    train_model(epoch=epoch)
    test_model(epoch=epoch)
print("completed!")
print('max_score', max_score.item())

Epoch 0 Training


0: 100%|██████████| 38/38 [01:47<00:00,  2.83s/it, loss:  0.5359 p:  0.3198  r:  0.3194  f1:  0.3196  iou:  0.1902]


Epoch 0 Testing


0: 100%|██████████| 2/2 [00:03<00:00,  1.86s/it, p:  0.3978  r:  0.4589  f1:  0.4262  iou:  0.2708]


f1 0.4453760606427993
max_score 0.4453760606427993
Epoch 1 Training


1: 100%|██████████| 38/38 [01:42<00:00,  2.69s/it, loss:  0.4890 p:  0.3792  r:  0.3431  f1:  0.3603  iou:  0.2197]


Epoch 1 Testing


1: 100%|██████████| 2/2 [00:03<00:00,  1.74s/it, p:  0.4839  r:  0.4611  f1:  0.4722  iou:  0.3091]


f1 0.49267218995263473
max_score 0.49267218995263473
Epoch 2 Training


2: 100%|██████████| 38/38 [01:34<00:00,  2.49s/it, loss:  0.4523 p:  0.4389  r:  0.3303  f1:  0.3769  iou:  0.2322]


Epoch 2 Testing


2: 100%|██████████| 2/2 [00:02<00:00,  1.10s/it, p:  0.5622  r:  0.4249  f1:  0.4840  iou:  0.3193]


f1 0.5050421166454132
max_score 0.5050421166454132
Epoch 3 Training


3: 100%|██████████| 38/38 [01:22<00:00,  2.17s/it, loss:  0.4262 p:  0.5675  r:  0.2709  f1:  0.3667  iou:  0.2245]


Epoch 3 Testing


3: 100%|██████████| 2/2 [00:02<00:00,  1.07s/it, p:  0.5938  r:  0.3861  f1:  0.4679  iou:  0.3054]


f1 0.49117201057579135
Epoch 4 Training


4: 100%|██████████| 38/38 [01:20<00:00,  2.13s/it, loss:  0.4667 p:  0.7099  r:  0.2746  f1:  0.3960  iou:  0.2469]


Epoch 4 Testing


4: 100%|██████████| 2/2 [00:02<00:00,  1.05s/it, p:  0.6479  r:  0.3585  f1:  0.4616  iou:  0.3000]


f1 0.48017667487870397
Epoch 5 Training


5: 100%|██████████| 38/38 [01:21<00:00,  2.13s/it, loss:  0.3761 p:  0.6103  r:  0.3084  f1:  0.4098  iou:  0.2577]


Epoch 5 Testing


5: 100%|██████████| 2/2 [00:02<00:00,  1.06s/it, p:  0.6599  r:  0.2999  f1:  0.4124  iou:  0.2598]


Epoch     6: reducing learning rate of group 0 to 1.0000e-05.
Epoch     6: reducing learning rate of group 0 to 1.0000e-05.
Epoch     6: reducing learning rate of group 0 to 1.0000e-05.
f1 0.4267426807142102
Epoch 6 Training


6: 100%|██████████| 38/38 [01:20<00:00,  2.13s/it, loss:  0.4570 p:  0.7392  r:  0.2972  f1:  0.4240  iou:  0.2690]


Epoch 6 Testing


6: 100%|██████████| 2/2 [00:02<00:00,  1.06s/it, p:  0.6672  r:  0.2913  f1:  0.4055  iou:  0.2543]


f1 0.4188410083723795
Epoch 7 Training


7: 100%|██████████| 38/38 [01:20<00:00,  2.13s/it, loss:  0.4124 p:  0.7486  r:  0.2883  f1:  0.4163  iou:  0.2628]


Epoch 7 Testing


7: 100%|██████████| 2/2 [00:02<00:00,  1.04s/it, p:  0.6763  r:  0.2828  f1:  0.3988  iou:  0.2491]


f1 0.4131095043547631
Epoch 8 Training


8: 100%|██████████| 38/38 [01:20<00:00,  2.13s/it, loss:  0.4089 p:  0.7090  r:  0.2535  f1:  0.3735  iou:  0.2296]


Epoch 8 Testing


8: 100%|██████████| 2/2 [00:02<00:00,  1.06s/it, p:  0.6813  r:  0.2910  f1:  0.4078  iou:  0.2561]


Epoch     9: reducing learning rate of group 0 to 1.0000e-06.
Epoch     9: reducing learning rate of group 0 to 1.0000e-06.
Epoch     9: reducing learning rate of group 0 to 1.0000e-06.
f1 0.42131995718744086
Epoch 9 Training


9: 100%|██████████| 38/38 [01:20<00:00,  2.13s/it, loss:  0.4171 p:  0.6967  r:  0.2439  f1:  0.3613  iou:  0.2205]


Epoch 9 Testing


9: 100%|██████████| 2/2 [00:02<00:00,  1.06s/it, p:  0.6827  r:  0.2931  f1:  0.4101  iou:  0.2580]

f1 0.4211871367429604
completed!
max_score 0.5050421166454132



