In [1]:
import torch
from torch.cuda.amp import autocast, GradScaler
import torchvision
import torchvision.transforms as T
import nbimporter
from unet import YouNet
# from dataset import get_data_loader
import dataset
import utils
import PIL
from torchvision.transforms.functional import to_pil_image
import os

In [None]:
NUM_WORKERS = 0
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-2
IMAGE_HEIGHT = 320
IMAGE_WIDTH = 480
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATA_DIR = 'C:/Users/Hayden/Machine Learning/d2l/d2l-en/pytorch/chapter_computer-vision/data/VOCdevkit/VOC2012/'
OUTPUT_DIR = 'C:/Users/Hayden/Desktop'
LOSS_WEIGHTS = torch.tensor([0.1] + [1.0] * 20).to(DEVICE)
CROP_SIZE = (500, 500)

In [3]:
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']


def predict(model, device, image):
    # print('___predict___')
    # test iterator is from DataLoader for test set.
    # X = test_iterator.dataset.normalize(image).unsqueeze(0)
    # X = test_iterator.dataset.normalize(image)
    # print(f'INPUT X.shape={X.shape}')
    with torch.no_grad():
        prediction = model(image.to(device)).argmax(1, keepdim=True)
    # return prediction.reshape(prediction.shape[1], prediction.shape[2])
    # print(f'prediction.shape={prediction.shape}')
    # print(f'prediction.argmax(1).shape={prediction.argmax(1).shape}')
    # return prediction.unsqueeze(0)
    return prediction

def prediction_to_image(prediction, device):
    # print('___prediction_to_image___')
    VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]
    color_map = torch.tensor(VOC_COLORMAP, device=device, dtype=torch.uint8)
    # print(f'PREDICTION_TO_IMAGE: color_map={color_map.shape} prediction={prediction.shape}')
    X = prediction.squeeze(1).long()
    # print(f'PREDICTION.shape={X.shape} UNIQUE={torch.unique(X)}')
    # print(f'PREDICTION_TO_IMAGE: color_map={color_map.shape} prediction.long={X.shape}')
    # print(f'color_map[X, :]={color_map[X, :].shape}')
    rgb = color_map[X, :].permute(0, 3, 1, 2)
    # print(f'RGB.shape={rgb.shape}')
    return rgb

In [4]:
def prediction_accuracy(prediction, label):
    # print('___prediction_accuracy___')
    # print(f'PREDICTION_ACCURACY: prediction={prediction.shape} label={label.shape}')
    if len(prediction.shape) > 1 and prediction.shape[1] > 1:
        prediction = torch.argmax(prediction, axis=1)
    label = label.squeeze(1)
    # print(f'PREDICTION.SHAPE={prediction.shape} LABEL.SHAPE={label.shape}')
    cmp = prediction.type(label.dtype) == label
    return float(torch.sum(cmp.type(label.dtype)))

def accuracy(prediction, label):
    num_matches = (prediction == label).sum().item()
    print(num_matches)


# a = torch.tensor([1, 2, 3, 4])
# b = torch.tensor([1, 0, 3, 5])
# accuracy(a, b)

In [5]:
def train_batch(loader, model, optimizer, loss_fn, scaler):
    # model.train()
    # total_loss, total_correct = 0, 0
    for batch_i, (X, y) in enumerate(loader):
        # print(f'BATCH {batch_i}')
        optimizer.zero_grad()
        
        X = X.to(DEVICE)
        y = y.to(DEVICE)

        # Forward propagation
        # with torch.autocast(device_type=DEVICE, dtype=torch.float16):
        with autocast():
            print('X.shape=', X.shape)
            y_hat = model(X)
            # print(f'Y_HAT: {y_hat.shape}')
            # loss = loss_fn(y_hat, y)
            # print(f'TRAIN_BATCH: y_hat={y_hat.shape} y.shape={y.shape} y.squeeze(0).long()={y.squeeze(0).long().shape} y.squeeze(1).long()={y.squeeze(1).long().shape}')
            # loss = loss_fn(y_hat, y)
            if y.min() < 0 or y.max() >= y_hat.shape[1]:
                print(f"❌ Invalid label value in batch: min={y.min().item()}, max={y.max().item()}")
                raise ValueError("Label contains out-of-bound class indices.")
            loss = loss_fn(y_hat, y.squeeze(1).long())
        
        # Backward propagation
        # loss.backward()
        # optimizer.step()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss_sum = loss.sum()
        # train_accuracy_sum = utils.prediction_accuracy(y_hat, y)
        # print(f'Y_HAT={y_hat.shape}, Y={y.shape}')
        train_accuracy_sum = prediction_accuracy(y_hat, y)

        return train_loss_sum, train_accuracy_sum

        # total_loss += loss.item()
        # total_correct += (y_hat.argmax(dim=1) == y).sum().item()
        # return total_loss / len(train_loader), total_correct / len(train_loader.dataset)

In [6]:
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

In [7]:
# def train(model, transforms, batch_size):
def train():
    model = YouNet(in_channels=3, out_channels=21).to(DEVICE)
    train_transform = T.Compose([
        T.ToTensor()
    ])
    validation_transform = T.Compose([
        T.ToTensor()
    ])
    test_transform = T.Compose([
        T.ToTensor()
    ])
    train_loader = dataset.get_data_loader(DATA_DIR, train_transform, train_transform, CROP_SIZE, 'train')
    validation_loader = dataset.get_data_loader(DATA_DIR, train_transform, train_transform, CROP_SIZE, 'val')
    
    # Initialize weights
    # def init_weights(module):
    #     if type(module) in [torch.nn.Linear, torch.nn.Conv2d]:
    #         torch.nn.init.normal_(module.weight, std=0.01)
    # model.apply(init_weights)

    loss_fn = torch.nn.CrossEntropyLoss()#ignore_index=0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scaler = GradScaler()
    
    for epoch in range(NUM_EPOCHS):
        print(f'EPOCH {epoch}:')

        model.train()
        loss, accuracy = train_batch(train_loader, model, optimizer, loss_fn, scaler)
        print(f'Training loss: {loss}')#\nPrediction accuracy: {accuracy}')

        # checkpoint = {
        #     'model': model.state_dict(),
        #     'optimizer': optimizer.state_dict()
        # }
        # utils.save_checkpoint(checkpoint)

        # TODO: Complete the predictions accuracy function and save predictions to disc.
        # predictions accuracy
        
        val_correct, val_total = 0, 0
        with torch.no_grad():
                model.eval()
                # print('___eval___')
                elements = None
                for idx, (val_images, val_masks) in enumerate(validation_loader):
                    val_images = val_images.to(DEVICE)
                    val_masks = val_masks.to(DEVICE)
                    
                    prediction = predict(model, DEVICE, val_images)
                    if elements is None:
                        elements = torch.unique(prediction)
                        print(f'elements.shape={elements.shape}')
                    else:
                        print(f'elements.shape={elements.shape}, torch.unique(prediction).shape={torch.unique(prediction).shape}')
                        elements = torch.cat((elements, torch.unique(prediction)), dim=0)
                        print(f'\telements.shape={elements.shape}')

                    val_correct += (prediction == val_masks).sum().item()
                    val_total += val_masks.numel()
                    # valid_acc = prediction_accuracy(prediction, val_masks)
                    # print(f'Validation accuracy:  {valid_acc}')


                    if epoch % 5 == 0:
                        prediction_image = prediction_to_image(prediction, DEVICE)

                        image_path = os.path.join(OUTPUT_DIR, f'images_{epoch}')
                        if not os.path.exists(image_path):
                            os.mkdir(image_path)
                        
                        for pred_image, val_image, val_mask in zip(prediction_image, val_images, val_masks):
                            pred_image = to_pil_image(pred_image.cpu())
                            pred_image.save(os.path.join(image_path, f'{idx}_prediction.png'))

                            val_image = dataset.denormalize(val_image.cpu())
                            val_image = to_pil_image(val_image.cpu())
                            val_image.save(os.path.join(image_path, f'{idx}_image.png'))

                            val_mask = dataset.mask_to_image(val_mask.cpu(), VOC_COLORMAP)
                            val_mask = to_pil_image(val_mask.squeeze(0).cpu())
                            val_mask.save(os.path.join(image_path, f'{idx}_mask.png'))
                
                valid_acc = val_correct / val_total
                print(f'Validation accuracy:  {valid_acc}')
                print(torch.unique(elements))
    # for test_image, test_mask in test_loader:
    #     prediction = utils.prediction_to_image(utils.predict(test_loader, model, DEVICE, test_image), DEVICE)
    #     test_acc = utils.prediction_accuracy(prediction, test_mask)
    #     print(f'Test Accuracy: {test_acc}')

        
        
        # utils.save_predictions(validation_loader, model, data_dir='predictions/', device=DEVICE)

In [8]:
train()

EPOCH 0:
X.shape= torch.Size([8, 3, 490, 490])
X.shape input=torch.Size([8, 3, 490, 490])
X.shape output=torch.Size([8, 64, 490, 490])
X.shape input=torch.Size([8, 64, 245, 245])
X.shape output=torch.Size([8, 128, 245, 245])
X.shape input=torch.Size([8, 128, 122, 122])
X.shape output=torch.Size([8, 256, 122, 122])
X.shape input=torch.Size([8, 256, 61, 61])
X.shape output=torch.Size([8, 512, 61, 61])
X.shape input=torch.Size([8, 1024, 30, 30])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 60 but got size 61 for tensor number 1 in the list.