In [1]:
import random

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from mydataset import MyDataset
from SiamUnet_diff import SiamUnet_diff
from data_augmentation import RandomRotation, RandomVerticalFlip, RandomHorizontalFlip, Compose, ToTensor

In [2]:
# Training hyper-parameters shared by the cells below.
batch_size = 100
lr = 0.001
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the RNG state so weight init and DataLoader shuffling repeat across
# "Restart & Run All". Python's `random` is seeded too in case the custom
# augmentations rely on it (NOTE(review): confirm which RNG data_augmentation uses).
# This does not make cuDNN kernels deterministic.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [3]:
# ImageNet per-channel mean / std, used to normalize both splits.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training split: tensor conversion plus random flip / right-angle rotation augmentation.
train_transform = Compose([
    ToTensor(),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomRotation([0, 90, 180, 270]),
])
train_set = MyDataset(
    root="data",
    is_train=True,
    transform=train_transform,
    normalize=transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
)

# Validation split: no augmentation, tensor conversion only.
val_set = MyDataset(
    root="data",
    is_train=False,
    transform=Compose([ToTensor()]),
    normalize=transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
)

# num_workers=0 loads batches in the main process; only training data is shuffled.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)

len(train_set), len(val_set), len(train_loader), len(val_loader)

(3800, 200, 38, 2)

In [4]:
# NOTE(review): arguments presumably are (input channels, number of classes);
# 2 classes matches the CrossEntropyLoss + argmax over dim 1 used below.
model = SiamUnet_diff(3, 2)
# Respect the device chosen in the config cell instead of assuming CUDA.
model.to(device)

criterion = nn.CrossEntropyLoss()
# weight_decay=1e-9 makes AdamW's decoupled weight decay practically a no-op.
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-9)
# mode='max': the scheduler is stepped with the validation F1 score, which should rise.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim, mode='max', factor=0.1, patience=2,
    verbose=True)
# Loss scaling is only needed for CUDA float16 autocast; disable it on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

In [5]:
def train_model(epoch):
    """Run one training epoch over ``train_loader`` with mixed precision.

    Uses the notebook-level ``model``, ``optim``, ``criterion``, ``scaler``,
    ``device`` and ``get_index``. The progress bar reports the loss and
    precision / recall / F1 / IoU of the *current* batch only.

    Args:
        epoch: epoch index, used only for logging.
    """
    model.train()
    print(f"Epoch {epoch} Training")
    # torch.cuda.amp.autocast only affects CUDA ops; turn it off explicitly on CPU.
    use_amp = device == "cuda"
    with tqdm(train_loader, desc=str(epoch)) as it:
        for img1, img2, mask in it:
            img1, img2, mask = img1.to(device), img2.to(device), mask.to(device)
            # CrossEntropyLoss wants integer class indices without a channel dim;
            # assumes masks arrive as (N, 1, H, W) — TODO confirm against MyDataset.
            mask = mask.long().squeeze(1)
            optim.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(img1, img2)
                loss = criterion(outputs, mask)
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            # Metrics are diagnostics only, so work on a detached view of the logits.
            _, pred = torch.max(outputs.detach(), 1)
            p, r, f1, iou = get_index(pred, mask)
            it.set_postfix_str(f"loss: {loss.item(): .4f} p: {p: .4f}  r: {r: .4f}  f1: {f1: .4f}  iou: {iou: .4f}")

In [6]:
def get_index(pred, label):
    eps = 1e-7
    tp = torch.sum(label * pred)
    fp = torch.sum(pred) - tp
    fn = torch.sum(label) - tp

    p = (tp + eps) / (tp + fp + eps)
    r = (tp + eps) / (tp + fn + eps)
    f1 = (2 * p * r + eps) / (p + r + eps)
    iou = (tp + eps) / (tp + fn + fp + eps)
    return p, r, f1, iou


def test_model(epoch):
    model.eval()
    global max_score
    f1s = 0
    print(f"Epoch {epoch} Testing")
    with torch.no_grad():
        with tqdm(val_loader, desc=str(epoch)) as it:
            for img1, img2, mask in it:
                img1, img2, mask = img1.cuda(), img2.cuda(), mask.cuda()
                outputs = model(img1, img2)
                _, pred = torch.max(outputs.data, 1)
                mask = mask.squeeze(1)
                p, r, f1, iou = get_index(pred, mask)
                f1s += f1
                it.set_postfix_str(f"p: {p: .4f}  r: {r: .4f}  f1: {f1: .4f}  iou: {iou: .4f}")
    f1s /= len(val_loader)
    scheduler.step(f1s)
    print("f1", f1s.item())
    if max_score < f1s:
        max_score = f1s
        print('max_score', max_score.item())

In [7]:
num_epoch = 10
# Best validation F1 so far; test_model() replaces it with a 0-dim tensor on improvement.
max_score = 0
for epoch in range(num_epoch):
    train_model(epoch=epoch)
    test_model(epoch=epoch)
print("completed!")
# float() works whether max_score is still the initial int 0 or a tensor;
# .item() would raise AttributeError if no epoch ever improved on 0.
print('max_score', float(max_score))

Epoch 0 Training


0: 100%|██████████| 38/38 [01:23<00:00,  2.21s/it, loss:  0.5302 p:  0.6012  r:  0.2345  f1:  0.3374  iou:  0.2029]


Epoch 0 Testing


0: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it, p:  0.5349  r:  0.4218  f1:  0.4716  iou:  0.3086]


f1 0.49367518315337927
max_score 0.49367518315337927
Epoch 1 Training


1: 100%|██████████| 38/38 [01:16<00:00,  2.03s/it, loss:  0.5076 p:  0.6351  r:  0.2903  f1:  0.3985  iou:  0.2488]


Epoch 1 Testing


1: 100%|██████████| 2/2 [00:02<00:00,  1.41s/it, p:  0.6482  r:  0.3557  f1:  0.4594  iou:  0.2982]


f1 0.4807392024009081
Epoch 2 Training


2: 100%|██████████| 38/38 [01:10<00:00,  1.84s/it, loss:  0.4682 p:  0.5477  r:  0.4156  f1:  0.4726  iou:  0.3094]


Epoch 2 Testing


2: 100%|██████████| 2/2 [00:01<00:00,  1.36it/s, p:  0.6360  r:  0.4517  f1:  0.5282  iou:  0.3589]


f1 0.5531057646337669
max_score 0.5531057646337669
Epoch 3 Training


3: 100%|██████████| 38/38 [00:57<00:00,  1.52s/it, loss:  0.4652 p:  0.5961  r:  0.4232  f1:  0.4950  iou:  0.3289]


Epoch 3 Testing


3: 100%|██████████| 2/2 [00:01<00:00,  1.39it/s, p:  0.6868  r:  0.4087  f1:  0.5125  iou:  0.3445]


f1 0.5341498207081414
Epoch 4 Training


4: 100%|██████████| 38/38 [00:55<00:00,  1.46s/it, loss:  0.4732 p:  0.6483  r:  0.4554  f1:  0.5350  iou:  0.3652]


Epoch 4 Testing


4: 100%|██████████| 2/2 [00:01<00:00,  1.40it/s, p:  0.6868  r:  0.4312  f1:  0.5298  iou:  0.3603]


f1 0.5484711960941804
Epoch 5 Training


5: 100%|██████████| 38/38 [00:55<00:00,  1.46s/it, loss:  0.4570 p:  0.6166  r:  0.5194  f1:  0.5638  iou:  0.3926]


Epoch 5 Testing


5: 100%|██████████| 2/2 [00:01<00:00,  1.43it/s, p:  0.6866  r:  0.4868  f1:  0.5697  iou:  0.3983]


f1 0.5928806022057967
max_score 0.5928806022057967
Epoch 6 Training


6: 100%|██████████| 38/38 [00:55<00:00,  1.47s/it, loss:  0.4733 p:  0.7697  r:  0.4527  f1:  0.5701  iou:  0.3987]


Epoch 6 Testing


6: 100%|██████████| 2/2 [00:01<00:00,  1.40it/s, p:  0.7208  r:  0.4225  f1:  0.5328  iou:  0.3631]


f1 0.5580525248934891
Epoch 7 Training


7: 100%|██████████| 38/38 [00:55<00:00,  1.47s/it, loss:  0.4686 p:  0.7007  r:  0.4593  f1:  0.5548  iou:  0.3839]


Epoch 7 Testing


7: 100%|██████████| 2/2 [00:01<00:00,  1.40it/s, p:  0.7264  r:  0.4070  f1:  0.5217  iou:  0.3529]


f1 0.5465619500283122
Epoch 8 Training


8: 100%|██████████| 38/38 [00:55<00:00,  1.47s/it, loss:  0.4619 p:  0.7282  r:  0.5309  f1:  0.6141  iou:  0.4431]


Epoch 8 Testing


8: 100%|██████████| 2/2 [00:01<00:00,  1.40it/s, p:  0.7590  r:  0.4060  f1:  0.5291  iou:  0.3597]


Epoch     9: reducing learning rate of group 0 to 1.0000e-04.
f1 0.5497141915179241
Epoch 9 Training


9: 100%|██████████| 38/38 [00:56<00:00,  1.49s/it, loss:  0.4715 p:  0.6740  r:  0.4884  f1:  0.5663  iou:  0.3950]


Epoch 9 Testing


9: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s, p:  0.7430  r:  0.4710  f1:  0.5765  iou:  0.4050]

f1 0.5932719428910385
max_score 0.5932719428910385
completed!
max_score 0.5932719428910385



