In [1]:
import torch.nn as nn
from torchvision import transforms
from mydataset import MyDataset
from torch.utils.data import  DataLoader
import torch
from myunet import Unet
from tqdm import tqdm
from torch.cuda.amp import autocast

In [4]:
# Number of samples per mini-batch; read by both DataLoaders below.
batch_size = 100

# Pick the GPU when PyTorch reports one is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [5]:
def _make_transform():
    """Build the ToTensor + Normalize pipeline handed to each dataset.

    The mean/std triples match the widely used ImageNet channel statistics.
    A fresh Compose is created per call so the two datasets do not share one.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])


# Training and validation splits are read from the same root folder and
# selected through the dataset's `is_train` flag.
train_set = MyDataset(root="data", is_train=True, transform=_make_transform())
val_set = MyDataset(root="data", is_train=False, transform=_make_transform())

# Only the training loader shuffles; both load in the main process
# (num_workers=0).
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)

# Cell output: sample counts followed by batch counts.
len(train_set), len(val_set), len(train_loader), len(val_loader)

(3800, 200, 38, 2)

In [6]:
# NOTE(review): Unet(6, 2) presumably means 6 input channels (two stacked
# 3-channel images) and 2 output classes — confirm against myunet.
model = Unet(6, 2)
# Move the model to the device chosen in the config cell rather than calling
# .cuda() unconditionally, which raises on machines without a usable GPU.
model.to(device)

criterion = nn.CrossEntropyLoss()
optim = torch.optim.AdamW(model.parameters(), lr=0.005, weight_decay=1e-9)
# mode='max' because the scheduler is stepped with the mean validation F1
# (see test_model), where higher is better. The LR is halved after the metric
# fails to improve for more than `patience` consecutive evaluations.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim, mode='max', factor=0.5, patience=2,
    verbose=True)
# Gradient scaling belongs to CUDA mixed precision, so only enable it when the
# selected device is a GPU. torch.device() accepts both a string and a device.
scaler = torch.cuda.amp.GradScaler(enabled=torch.device(device).type == "cuda")

In [7]:
def train_model(epoch):
    """Run one mixed-precision training epoch over ``train_loader``.

    Each batch (two images and a mask) is moved to the notebook-level
    ``device``. The forward pass and loss run under autocast, and the backward
    pass and optimizer step go through the shared ``scaler``. Precision,
    recall, F1 and IoU of the current batch are shown in the progress bar.

    Args:
        epoch: Epoch index, used for the log line and the progress-bar label.

    Returns:
        None. The function updates ``model``, ``optim`` and ``scaler`` in place.
    """
    model.train()
    print(f"Epoch {epoch} Training")
    # torch.cuda.amp.autocast targets CUDA, so switch it on only for a GPU run.
    use_amp = torch.device(device).type == "cuda"
    with tqdm(train_loader, desc=str(epoch)) as it:
        for img1, img2, mask in it:
            img1, img2, mask = img1.to(device), img2.to(device), mask.to(device)
            # CrossEntropyLoss takes integer class indices as the target.
            mask = mask.long()
            optim.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(img1, img2)
                loss = criterion(outputs, mask)
            # Scaled backward/step/update sequence required by GradScaler.
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            # Class with the highest score along dim 1; detached so the metric
            # computation stays out of the autograd graph.
            pred = outputs.detach().argmax(dim=1)
            p, r, f1, iou = get_index(pred, mask)
            it.set_postfix_str(f"loss: {loss.item(): .4f} p: {p: .4f}  r: {r: .4f}  f1: {f1: .4f}  iou: {iou: .4f}")

In [8]:
def get_index(pred, label):
    """Compute smoothed precision, recall, F1 and IoU for a predicted mask.

    The counts are taken over every element of the tensors, so a whole batch
    is scored as one unit. Both inputs are assumed to hold 0/1 values — this
    is not checked here.

    Args:
        pred: Tensor of predicted labels.
        label: Tensor of ground-truth labels, broadcastable with ``pred``.

    Returns:
        Tuple ``(precision, recall, f1, iou)`` of 0-dim tensors.
    """
    # Small constant keeping every ratio finite when a count is zero.
    smooth = 1e-7

    hits = (label * pred).sum()        # true positives
    false_pos = pred.sum() - hits      # predicted positive, labelled negative
    false_neg = label.sum() - hits     # labelled positive, predicted negative

    smoothed_hits = hits + smooth
    precision = smoothed_hits / (hits + false_pos + smooth)
    recall = smoothed_hits / (hits + false_neg + smooth)
    f_score = (2 * precision * recall + smooth) / (precision + recall + smooth)
    jaccard = smoothed_hits / (hits + false_neg + false_pos + smooth)
    return precision, recall, f_score, jaccard


def test_model(epoch):
    """Evaluate on ``val_loader``, step the LR scheduler and track the best F1.

    The F1 of every validation batch is summed and divided by the number of
    batches (an unweighted mean over batches, not an F1 over pooled counts).
    That mean is passed to ``scheduler.step`` and replaces the notebook-level
    ``max_score`` whenever it is higher.

    Args:
        epoch: Epoch index, used for the log line and the progress-bar label.

    Returns:
        None. Side effects: prints metrics, steps ``scheduler`` and may rebind
        the global ``max_score`` to a 0-dim tensor.
    """
    model.eval()
    global max_score
    f1s = 0
    print(f"Epoch {epoch} Testing")
    with torch.no_grad():
        with tqdm(val_loader, desc=str(epoch)) as it:
            for img1, img2, mask in it:
                # Use the configured device instead of a hard-coded .cuda().
                img1, img2, mask = img1.to(device), img2.to(device), mask.to(device)
                outputs = model(img1, img2)
                # Class with the highest score along dim 1.
                pred = outputs.argmax(dim=1)
                p, r, f1, iou = get_index(pred, mask)
                f1s += f1
                it.set_postfix_str(f"p: {p: .4f}  r: {r: .4f}  f1: {f1: .4f}  iou: {iou: .4f}")
    f1s /= len(val_loader)
    scheduler.step(f1s)
    print("f1", f1s.item())
    # max_score is kept as a tensor because the driver cell calls .item() on it.
    if max_score < f1s:
        max_score = f1s
        print('max_score', max_score.item())

In [9]:
# Number of train/validate rounds to run.
num_epoch = 10
# Best mean validation F1 so far; test_model() rebinds it through `global`
# (to a 0-dim tensor) whenever an epoch scores higher.
max_score = 0
for epoch in range(num_epoch):
    train_model(epoch=epoch)
    test_model(epoch=epoch)
print("completed!")
# float() accepts both the initial int and the tensor stored by test_model, so
# the summary no longer raises AttributeError when no epoch replaced the
# initial value (e.g. num_epoch == 0).
print('max_score', float(max_score))

Epoch 0 Training


0: 100%|██████████| 38/38 [00:30<00:00,  1.23it/s, loss:  0.5027 p:  0.7639  r:  0.1845  f1:  0.2973  iou:  0.1746]


Epoch 0 Testing


0: 100%|██████████| 2/2 [00:01<00:00,  1.45it/s, p:  0.3440  r:  0.9259  f1:  0.5017  iou:  0.3348]


f1 0.5137138375704662
max_score 0.5137138375704662
Epoch 1 Training


1: 100%|██████████| 38/38 [00:30<00:00,  1.26it/s, loss:  0.4898 p:  0.6066  r:  0.4227  f1:  0.4982  iou:  0.3317]


Epoch 1 Testing


1: 100%|██████████| 2/2 [00:01<00:00,  1.44it/s, p:  0.6708  r:  0.6055  f1:  0.6365  iou:  0.4668]


f1 0.6515043247016217
max_score 0.6515043247016217
Epoch 2 Training


2: 100%|██████████| 38/38 [00:30<00:00,  1.25it/s, loss:  0.4759 p:  0.5903  r:  0.4059  f1:  0.4810  iou:  0.3166]


Epoch 2 Testing


2: 100%|██████████| 2/2 [00:01<00:00,  1.44it/s, p:  0.6723  r:  0.5550  f1:  0.6080  iou:  0.4368]


f1 0.6300655790938263
Epoch 3 Training


3: 100%|██████████| 38/38 [00:30<00:00,  1.26it/s, loss:  0.5063 p:  0.6604  r:  0.3850  f1:  0.4865  iou:  0.3214]


Epoch 3 Testing


3: 100%|██████████| 2/2 [00:01<00:00,  1.49it/s, p:  0.6948  r:  0.5751  f1:  0.6294  iou:  0.4592]


f1 0.6485055144982227
Epoch 4 Training


4: 100%|██████████| 38/38 [00:30<00:00,  1.25it/s, loss:  0.4610 p:  0.7500  r:  0.5408  f1:  0.6284  iou:  0.4582]


Epoch 4 Testing


4: 100%|██████████| 2/2 [00:01<00:00,  1.45it/s, p:  0.6603  r:  0.6238  f1:  0.6415  iou:  0.4722]


f1 0.6602825638084134
max_score 0.6602825638084134
Epoch 5 Training


5: 100%|██████████| 38/38 [00:30<00:00,  1.24it/s, loss:  0.4741 p:  0.5580  r:  0.4473  f1:  0.4965  iou:  0.3303]


Epoch 5 Testing


5: 100%|██████████| 2/2 [00:01<00:00,  1.47it/s, p:  0.7351  r:  0.4597  f1:  0.5656  iou:  0.3943]


f1 0.5965424098869234
Epoch 6 Training


6: 100%|██████████| 38/38 [00:30<00:00,  1.25it/s, loss:  0.4597 p:  0.6937  r:  0.4993  f1:  0.5807  iou:  0.4091]


Epoch 6 Testing


6: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s, p:  0.7425  r:  0.4681  f1:  0.5742  iou:  0.4027]


f1 0.5909715696462075
Epoch 7 Training


7: 100%|██████████| 38/38 [00:30<00:00,  1.26it/s, loss:  0.4394 p:  0.7699  r:  0.5537  f1:  0.6441  iou:  0.4751]


Epoch 7 Testing


7: 100%|██████████| 2/2 [00:01<00:00,  1.46it/s, p:  0.6923  r:  0.5789  f1:  0.6305  iou:  0.4604]


Epoch     8: reducing learning rate of group 0 to 2.5000e-03.
f1 0.6559516876992525
Epoch 8 Training


8: 100%|██████████| 38/38 [00:30<00:00,  1.25it/s, loss:  0.4716 p:  0.7360  r:  0.4432  f1:  0.5533  iou:  0.3824]


Epoch 8 Testing


8: 100%|██████████| 2/2 [00:01<00:00,  1.45it/s, p:  0.7152  r:  0.5390  f1:  0.6147  iou:  0.4438]


f1 0.6389766343022689
Epoch 9 Training


9: 100%|██████████| 38/38 [00:30<00:00,  1.25it/s, loss:  0.4549 p:  0.6705  r:  0.5752  f1:  0.6192  iou:  0.4484]


Epoch 9 Testing


9: 100%|██████████| 2/2 [00:01<00:00,  1.46it/s, p:  0.7276  r:  0.5020  f1:  0.5941  iou:  0.4226]

f1 0.6256118832324173
completed!
max_score 0.6602825638084134



