In [1]:
from model import UNET

In [2]:
from importnb import Notebook

with Notebook():
    from CustomDataClass import CarsDataSet
    from UtilsEdited import save_checkpoint, load_checkpoint, get_loaders, check_accuracy

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm

###### Hyperparameters

In [4]:
# Training configuration. Everything a reader might want to tune lives here.

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Optimisation settings.
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
NUM_EPOCHS = 15

# Target size passed to the resize step of the transform pipeline.
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240

# When True, a saved checkpoint is loaded before training starts.
LOAD_MODEL = False

In [5]:
# Show which device was selected so it is recorded in the notebook output.
print(DEVICE)

cuda


###### Data Directories

In [6]:
# NOTE: machine-specific absolute location; adjust when running elsewhere.
_SESSION_DIR = "C:\\Users\\georg\\Desktop\\inmind-material\\Week-6\\Session-1"

# Input images and their masks sit side by side under the session folder.
train_dir = _SESSION_DIR + "\\train"
mask_dir = _SESSION_DIR + "\\train_masks"

###### Data Transform Function

In [7]:
# Steps of the albumentations pipeline, applied in list order.
_pipeline_steps = [
    # Resize every sample to the configured working resolution.
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
    # Rotation is always applied (p=1.0) with an angle limit of 35.
    A.Rotate(limit=35, p=1.0),
    # Random flips: horizontal half of the time, vertical one time in ten.
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),
    # With mean 0, std 1 and max_pixel_value 255 this presumably just
    # rescales pixel values to [0, 1] - confirm against albumentations docs.
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    # Final step: convert the result to a torch tensor.
    ToTensorV2(),
]

transform = A.Compose(_pipeline_steps)

###### Train

In [8]:
def train_fn(loader, model, optimizer, loss_fn, epoch, num_epochs):
    """Run one training epoch over ``loader``, updating ``model`` in place.

    Args:
        loader: Iterable of ``(data, targets)`` batches. It must support
            ``len()`` so the progress bar can show a total.
        model: Network being trained; it is called on each batch of ``data``.
        optimizer: Optimizer whose ``zero_grad``/``step`` apply the update.
        loss_fn: Callable taking ``(predictions, targets)`` and returning a
            loss tensor that supports ``backward()`` and ``item()``.
        epoch: Zero-based index of the current epoch (displayed one-based).
        num_epochs: Total number of epochs; used only for the progress label.

    Returns:
        None. Progress and the most recent batch loss are shown on a tqdm bar.
    """
    # Put the model in training mode explicitly so this function does not
    # depend on the mode a preceding evaluation pass left behind.
    model.train()

    # The epoch label is constant for the whole pass, so set it once up front.
    loop = tqdm(loader, total=len(loader), desc=f"Epoch [{epoch+1}/{num_epochs}]")

    for data, targets in loop:
        data = data.to(device=DEVICE)
        # unsqueeze(1) inserts a dimension at index 1; presumably the channel
        # axis, so the mask lines up with the model output - TODO confirm.
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward pass: call the module itself rather than .forward() so that
        # any registered hooks run as well.
        predictions = model(data)
        loss = loss_fn(predictions, targets)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update tqdm loop with the loss of the batch just processed
        loop.set_postfix(loss=loss.item())

In [9]:
# Build the network (3 input channels, 1 output channel), then move it to the
# selected device.
model = UNET(in_channels=3, out_channels=1)
model = model.to(DEVICE)

In [10]:
# Adam over all model parameters at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Binary cross-entropy computed on raw logits (the sigmoid is applied inside
# the loss), so the model output is passed to it directly.
loss_fn = nn.BCEWithLogitsLoss()

In [11]:
# Build the training and validation loaders from the image/mask folders.
# NOTE(review): the same `transform` is handed over for both loaders here;
# how get_loaders splits the data and applies it is not visible in this file.
train_loader, val_loader = get_loaders(train_dir, mask_dir, BATCH_SIZE, transform)

In [12]:
# Optionally restore saved weights and report their validation metrics before
# any further training.
if LOAD_MODEL:
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU machine can still be loaded on CPU.
    # Only the model is handed to load_checkpoint; the optimizer state held in
    # the checkpoint is not restored by this cell.
    load_checkpoint(torch.load('my_checkpoint.pth.tar', map_location=DEVICE), model)
    check_accuracy(val_loader, model, device=DEVICE)

In [13]:
# Validation metrics collected once per epoch, in epoch order.
# Re-running this cell resets both lists.
Accuracy = []
DiceScores = []

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, epoch, NUM_EPOCHS)
    # Report the epoch that actually finished (one-based, to match the
    # progress-bar label) instead of a fixed "epoch one" message.
    print(f"epoch {epoch + 1} done")

    # save model: weights plus optimizer state, so training can be resumed
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    # check accuracy on the validation loader and record both metrics
    acc, dice_score = check_accuracy(val_loader, model, device=DEVICE)
    Accuracy.append(acc)
    DiceScores.append(dice_score)
    

  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38641070/39091200 with acc 98.85
Dice score: 0.9725877642631531


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38563232/39091200 with acc 98.65
Dice score: 0.9687594771385193


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38755439/39091200 with acc 99.14
Dice score: 0.9798970222473145


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38798224/39091200 with acc 99.25
Dice score: 0.9823131561279297


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38824450/39091200 with acc 99.32
Dice score: 0.9839661717414856


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38616262/39091200 with acc 98.79
Dice score: 0.9716569781303406


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38832417/39091200 with acc 99.34
Dice score: 0.9843029975891113


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38779642/39091200 with acc 99.20
Dice score: 0.9811540246009827


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38854301/39091200 with acc 99.39
Dice score: 0.9857268929481506


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38860937/39091200 with acc 99.41
Dice score: 0.9860703945159912


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38862875/39091200 with acc 99.42
Dice score: 0.9862521886825562


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38861188/39091200 with acc 99.41
Dice score: 0.9861713647842407


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38877535/39091200 with acc 99.45
Dice score: 0.9870994687080383


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38865218/39091200 with acc 99.42
Dice score: 0.9863752126693726


  0%|          | 0/255 [00:00<?, ?it/s]

epoch one done
=> Saving checkpoint
Got 38881273/39091200 with acc 99.46
Dice score: 0.9873620271682739
