In [1]:
import sys
import matplotlib.pyplot as plt

import torch
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from torchview import draw_graph
from network import SampleModel
from dataset import SampleDataset
from sklearn.model_selection import train_test_split
from pathlib import Path
from pa228_tools import train, validate
import glob

In [2]:
def loss_batch(model, loss_func, xb, yb, dev, opt=None):
    
    xb, yb = xb.to(dev), yb.to(dev)
    pred = model(xb)
    print('got_out')
    yb = yb.argmax(dim=3)
    # print(yb.shape)
    # print(pred.shape)
    print('loss')
    loss = loss_func(pred, yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)


def train(model, train_dl, loss_func, dev, opt):
        
        model.train()
        loss, size = 0, 0
        for b_idx, (xb, yb) in tqdm(enumerate(train_dl), total=len(train_dl), leave=False):
            print('batching')
            b_loss, b_size = loss_batch(model, loss_func, xb, yb, dev, opt)
            print('done batching')

            loss += b_loss * b_size
            size += b_size
            
        return loss / size
    
    
def validate(model, valid_dl, loss_func, dev, opt=None):
        
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb, dev) for xb, yb in valid_dl]
            )
            
        return np.sum(np.multiply(losses, nums)) / np.sum(nums)

In [3]:
def fit(net, batch_size, epochs, trainloader, validloader, loss_fn, optimizer, device):
    train_losses = []
    validation_losses = []

    for epoch in tqdm(range(epochs), 'epochs'):
        print('training')
        loss = train(net, trainloader, loss_fn, device, optimizer)
        print('validating')
        val_loss = validate(net, validloader, loss_fn, device)

        train_losses.append(loss)
        validation_losses.append(val_loss)
        print(f'epoch {epoch+1}/{epochs}, loss: {loss : .05f}, validation loss: {val_loss:.05f}')

      
    print('Training finished!')
    return train_losses, validation_losses

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Computing with {}!'.format(device))


Computing with cpu!


In [None]:

# config dictionary
config = {
'batch_size': 2,
'epoch': 1,
'num_workers': 1,
'dropout': 0.5,
'lr': 0.0001,
'optimizer':'Adam',
'img_size': 128,
'n_classes': 2
}

PATH = Path('{}'.format('data'), 'data_seg_public')
img_dir = PATH / 'img'
mask_dir = PATH / 'mask'
img_files = glob.glob("{}/*/*.png".format(img_dir))
mask_files = glob.glob("{}/*/*.png".format(mask_dir))
df = pd.DataFrame({'img': img_files, 'mask': mask_files})

import albumentations as A
from albumentations.pytorch import ToTensorV2

transforms = A.Compose([
                        A.SmallestMaxSize (512),
                        A.CenterCrop(512, 1024),
                        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                        ToTensorV2(),
                        ]   
                    )

train_df, valid_df = train_test_split(df, test_size=.3, random_state=2)
traindataset, valdataset = SampleDataset(train_df, transforms=transforms), SampleDataset(valid_df, transforms=transforms)

trainloader = torch.utils.data.DataLoader(traindataset,
                    batch_size=config['batch_size'],
                    shuffle=False,
                    num_workers=config['num_workers'])

valloader = torch.utils.data.DataLoader(valdataset,
                    batch_size=config['batch_size'],
                    shuffle=False,
                    num_workers=config['num_workers'])


In [7]:


net = SampleModel(num_class=8)
# input_sample = torch.zeros((1, 512, 1024))
# draw_network_architecture(net, input_sample)

# define optimizer and learning rate
optimizer = torch.optim.Adam(net.parameters(), lr=config['lr'])

# define loss function
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=2)

# train the network for three epochs
tr_losses, val_losses = fit(net, config['batch_size'], config['epoch'], trainloader, valloader, loss_fn, optimizer, device)


epochs:   0%|          | 0/1 [00:00<?, ?it/s]

training


  0%|          | 0/1143 [00:00<?, ?it/s]

batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching
got_out
loss
done batching
batching


KeyboardInterrupt: 