In [1]:
import torch
from torch.cuda.amp import autocast, GradScaler
import torchvision
import torchvision.transforms as T
import nbimporter
from unet import YouNet, UNet
import dataset
import utils
import PIL
from torchvision.transforms.functional import to_pil_image
import os
import numpy as np
import yaml
print(yaml.__file__)
import carvana

c:\Users\Hayden\anaconda3\envs\d2l\lib\site-packages\yaml\__init__.py


In [2]:
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

NUM_WORKERS = config["hyperparams"]["NUM_WORKERS"]
BATCH_SIZE = config["hyperparams"]["BATCH_SIZE"]
NUM_EPOCHS = config["hyperparams"]["NUM_EPOCHS"]
LEARNING_RATE = float(config["hyperparams"]["LEARNING_RATE"])
WEIGHT_DECAY = float(config["hyperparams"]["WEIGHT_DECAY"])
IMAGE_HEIGHT = config["hyperparams"]["IMAGE_HEIGHT"]
IMAGE_WIDTH = config["hyperparams"]["IMAGE_WIDTH"]
CROP_SIZE = config["hyperparams"]["CROP_SIZE"]

device_str = config.get("device", "auto")
if device_str == "auto":
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    DEVICE = torch.device(device_str)
    if DEVICE.type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available but 'cuda' was specified.")


DATA_DIR = config["paths"]["DATA_DIR"]
OUTPUT_DIR = config["paths"]["OUTPUT_DIR"]
WEIGHTS_DIR = config["paths"]["WEIGHTS_DIR"]

class_weights = np.load(WEIGHTS_DIR)
LOSS_WEIGHTS_TENSOR = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

VOC_COLORMAP = config["voc_dataset"]["VOC_COLORMAP"]
VOC_CLASSES = config["voc_dataset"]["VOC_CLASSES"]

print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
def predict(model, device, image):
    with torch.no_grad():
        prediction = model(image.to(device)).argmax(1, keepdim=True)
    return prediction

def prediction_to_image(prediction, device):
    VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]
    color_map = torch.tensor(VOC_COLORMAP, device=device, dtype=torch.uint8)
    X = prediction.squeeze(1).long()
    rgb = color_map[X, :].permute(0, 3, 1, 2)
    return rgb

In [4]:
def prediction_accuracy(prediction, label):
    if len(prediction.shape) > 1 and prediction.shape[1] > 1:
        prediction = torch.argmax(prediction, axis=1)
    label = label.squeeze(1)
    cmp = prediction.type(label.dtype) == label
    return float(torch.sum(cmp.type(label.dtype)))

def accuracy(prediction, label):
    """
    Example Usage:
        a = torch.tensor([1, 2, 3, 4])
        b = torch.tensor([1, 0, 3, 5])
        accuracy(a, b) -> torch.tensor([1, 0, 1, 0])
    """
    num_matches = (prediction == label).sum().item()
    print(num_matches)
    return num_matches

def compute_accuracy(pred, label):
    pred = pred.argmax(1) if pred.dim() > 1 and pred.size(1) > 1 else pred
    label = label.squeeze(1) if label.dim() == 4 else label
    correct = (pred == label).sum().item()
    total = label.numel()
    return correct / total

In [5]:
def train_batch(loader, model, optimizer, loss_fn, scaler):
    model.train()
    for batch_i, (X, y) in enumerate(loader):
        # print(f'BATCH {batch_i}')
        optimizer.zero_grad()
        
        X = X.to(DEVICE)
        y = y.to(DEVICE)

        # Forward propagation
        # with torch.autocast(device_type=DEVICE, dtype=torch.float16):
        # with autocast():
        print('X.shape=', X.shape)
        y_hat = model(X)
        if y.min() < 0 or y.max() >= y_hat.shape[1]:
            print(f"❌ Invalid label value in batch: min={y.min().item()}, max={y.max().item()}")
            raise ValueError("Label contains out-of-bound class indices.")
        loss = loss_fn(y_hat, y.squeeze(1).long())
        
        # Backward propagation
        # loss.backward()
        # optimizer.step()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss_sum = loss.sum()
        train_accuracy_sum = prediction_accuracy(y_hat, y)

        return train_loss_sum, train_accuracy_sum

        # total_loss += loss.item()
        # total_correct += (y_hat.argmax(dim=1) == y).sum().item()
        # return total_loss / len(train_loader), total_correct / len(train_loader.dataset)

In [6]:
# def train(model, transforms, batch_size):
def train():
    model = YouNet(in_channels=3, out_channels=21).to(DEVICE)
    # model = UNet(in_channels=3, out_channels=21).to(DEVICE)
    train_transform = T.Compose([
        T.ToTensor()
    ])
    validation_transform = T.Compose([
        T.ToTensor()
    ])
    test_transform = T.Compose([
        T.ToTensor()
    ])
    train_loader = dataset.get_data_loader(DATA_DIR, train_transform, train_transform, CROP_SIZE, 'train', BATCH_SIZE)
    validation_loader = dataset.get_data_loader(DATA_DIR, train_transform, train_transform, CROP_SIZE, 'val', BATCH_SIZE)
    
    # Initialize weights
    # def init_weights(module):
    #     if type(module) in [torch.nn.Linear, torch.nn.Conv2d]:
    #         torch.nn.init.normal_(module.weight, std=0.01)
    # model.apply(init_weights)

    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scaler = GradScaler()
    
    for epoch in range(NUM_EPOCHS+1):
        print(f'EPOCH {epoch}:')

        model.train()
        loss, accuracy = train_batch(train_loader, model, optimizer, loss_fn, scaler)
        print(f'Training loss: {loss}')#\nPrediction accuracy: {accuracy}')

        # checkpoint = {
        #     'model': model.state_dict(),
        #     'optimizer': optimizer.state_dict()
        # }
        # utils.save_checkpoint(checkpoint)

        # TODO: Complete the predictions accuracy function and save predictions to disc.
        # predictions accuracy
        
        val_correct, val_total = 0, 0
        with torch.no_grad():
                model.eval()
                elements = None
                for idx, (val_images, val_masks) in enumerate(validation_loader):
                    val_images = val_images.to(DEVICE)
                    val_masks = val_masks.to(DEVICE)
                    
                    prediction = predict(model, DEVICE, val_images)
                    if elements is None:
                        elements = torch.unique(prediction)
                    else:
                        elements = torch.cat((elements, torch.unique(prediction)), dim=0)

                    val_correct += (prediction == val_masks).sum().item()
                    val_total += val_masks.numel()
                    # valid_acc = prediction_accuracy(prediction, val_masks)
                    # print(f'Validation accuracy:  {valid_acc}')


                    if epoch % 5 == 0:
                        prediction_image = prediction_to_image(prediction, DEVICE)

                        image_path = os.path.join(OUTPUT_DIR, f'images_{epoch}')
                        if not os.path.exists(image_path):
                            os.mkdir(image_path)
                        
                        for pred_image, val_image, val_mask in zip(prediction_image, val_images, val_masks):
                            pred_image = to_pil_image(pred_image.cpu())
                            pred_image.save(os.path.join(image_path, f'{idx}_prediction.png'))

                            val_image = dataset.denormalize(val_image.cpu())
                            val_image = to_pil_image(val_image.cpu())
                            val_image.save(os.path.join(image_path, f'{idx}_image.png'))

                            val_mask = dataset.mask_to_image(val_mask.cpu(), VOC_COLORMAP)
                            val_mask = to_pil_image(val_mask.squeeze(0).cpu())
                            val_mask.save(os.path.join(image_path, f'{idx}_mask.png'))
                
                valid_acc = val_correct / val_total
                print(f'Validation accuracy:  {valid_acc}')
                print(torch.unique(elements))
    # for test_image, test_mask in test_loader:
    #     prediction = utils.prediction_to_image(utils.predict(test_loader, model, DEVICE, test_image), DEVICE)
    #     test_acc = utils.prediction_accuracy(prediction, test_mask)
    #     print(f'Test Accuracy: {test_acc}')

        
        
        # utils.save_predictions(validation_loader, model, data_dir='predictions/', device=DEVICE)

In [7]:
train()

in_channels=3, out_channels=64
in_channels=64, out_channels=128
in_channels=128, out_channels=256
in_channels=256, out_channels=512
in_channels=512, out_channels=1024
in_channels=1024, out_channels=512
in_channels=512, out_channels=256
in_channels=256, out_channels=128
in_channels=128, out_channels=64
4
4
EPOCH 0:
X.shape= torch.Size([1, 3, 256, 256])
Training loss: 3.0390970706939697
Validation accuracy:  0.000659942626953125
tensor([1, 3, 6, 7], device='cuda:0')
EPOCH 1:
X.shape= torch.Size([1, 3, 256, 256])
Training loss: 3.0381815433502197
Validation accuracy:  0.0341644287109375
tensor([1, 3, 6, 7], device='cuda:0')
EPOCH 2:
X.shape= torch.Size([1, 3, 256, 256])
Training loss: 3.037264585494995
Validation accuracy:  0.08246612548828125
tensor([1, 3, 7], device='cuda:0')
EPOCH 3:
X.shape= torch.Size([1, 3, 256, 256])
Training loss: 3.036228656768799
Validation accuracy:  0.07462692260742188
tensor([1, 3], device='cuda:0')
EPOCH 4:
X.shape= torch.Size([1, 3, 256, 256])
Training loss