In [None]:
import os
from dataclasses import dataclass
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageFile

In [None]:
@dataclass
class Config:
    """Per-environment settings: file locations plus the torch device name.

    All fields are required constructor arguments. Field meanings follow
    from their names; the ones this class itself relies on are:

    * ``training_output_folder`` - created by :meth:`init` when training.
    * ``device`` - a device string such as ``'cpu'`` or ``'cuda'``.

    ``training`` is not a dataclass field; it is assigned by :meth:`init`.
    """
    train_csv_filepath: str
    test_csv_filepath: str
    submission_filepath: str
    images_root_folder: str
    training_output_folder: str
    saved_weights_filepath: str
    device: str

    # Must be spelled with trailing underscores: ``dataclass`` only calls a
    # method named exactly ``__post_init__``. ``__post_init`` would be
    # name-mangled to ``_Config__post_init`` and never run.
    def __post_init__(self):
        """ For configuration variables that are shared across environments """
        ...

    # noinspection PyAttributeOutsideInit
    def init(self, training):
        """ Adjust configuration setup for training vs inference

        Args:
            training: truthy when the config is about to be used for a
                training run; stored as ``self.training``.

        Side effects:
            When ``training`` is truthy, creates ``training_output_folder``
            (including parents). An already existing folder is not an error.
        """
        self.training = training

        if self.training:
            os.makedirs(self.training_output_folder, exist_ok=True)

        ...

# Active configuration. Set to environment-relevant config before
# training/inference; it stays None until a Config instance is assigned.
config: Optional[Config] = None

In [None]:
# Settings for a local run. Every path is relative, so it resolves against
# the current working directory.
local_config = Config(
    train_csv_filepath='data/train.csv',
    test_csv_filepath='data/test.csv',
    images_root_folder='data/images',
    submission_filepath='data/submission.csv',
    training_output_folder='data_gen/training_output/',
    saved_weights_filepath='data_gen/training_output/model_weights.pth',
    device='cpu',
)

# Settings for a Kaggle run. The 'TODO' values look like placeholders that
# still need real paths before this config can be used.
kaggle_config = Config(
    train_csv_filepath='TODO',
    test_csv_filepath='TODO',
    images_root_folder='TODO',
    submission_filepath='TODO',
    training_output_folder='TODO',
    saved_weights_filepath='TODO',
    device='cuda',
)

In [None]:
# Three float32 values each: one mean and one standard deviation per channel
# (named for ImageNet). Both arrays are 1-D with shape (3,).
imagenet_mean_array = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
imagenet_std_array = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

def gpu_tensor(numpy_array):
    """Build a new torch tensor from ``numpy_array`` on ``config.device``.

    Despite the name, the tensor lands on whatever device the active config
    names, which may be ``'cpu'``. Requires the module-level ``config`` to be
    set before the call.
    """
    target_device = config.device
    return torch.tensor(numpy_array, device=target_device)

def visualize_image(image_tensor):
    """Undo normalization on an image tensor and show it with matplotlib.

    ``image_tensor`` must be 3-D and channels-first: it is permuted with
    ``(1, 2, 0)`` to move the first axis last before plotting. The input
    tensor should be on gpu - presumably the same device as the mean/std
    tensors read from ``config``.
    """
    # NOTE(review): these two attributes are not declared on Config in the
    # visible source - presumably assigned elsewhere; confirm before relying on it.
    mean = config.channelwise_imagenet_mean_gpu_tensor
    std = config.channelwise_imagenet_std_gpu_tensor

    # Back to the [0, 1] range, clipping anything that falls outside it.
    restored = denormalize(image_tensor, mean, std).clamp(0, 1)

    # Channels-last numpy copy on the host, scaled to 8-bit.
    # The uint8 cast truncates the fractional part.
    channels_last = restored.permute(1, 2, 0).cpu().numpy()
    pixels = (channels_last * 255).astype('uint8')

    plt.imshow(pixels)
    plt.axis('off')
    plt.show()

def normalize(tensor, mean, std):
    """Return ``(tensor - mean) / std``.

    Works on anything supporting ``-`` and ``/`` (scalars, numpy arrays,
    torch tensors); broadcasting follows the operand types' own rules.
    """
    centered = tensor - mean
    return centered / std

def denormalize(tensor, mean, std):
    """Return ``tensor * std + mean``, the inverse of ``normalize``.

    Works on anything supporting ``*`` and ``+`` (scalars, numpy arrays,
    torch tensors); broadcasting follows the operand types' own rules.
    """
    rescaled = tensor * std
    return rescaled + mean

def load_pil_image_from_id(image_id) -> ImageFile.ImageFile:
    """Open the JPEG for ``image_id`` under ``config.images_root_folder``.

    The path is ``<images_root_folder>/<image_id>.jpg``. It is built with
    ``os.path.join`` so the result is correct whether or not the configured
    folder ends with a path separator; plain string concatenation would turn
    ``'data/images'`` + ``'abc'`` into ``'data/imagesabc.jpg'``.

    Args:
        image_id: image identifier without extension. Formatted with
            ``str()``, so non-string ids (e.g. ints) are accepted too.

    Returns:
        The object returned by ``PIL.Image.open``. Per the Pillow docs the
        open is lazy and keeps the file handle open until the image data is
        loaded or the image is closed.
    """
    image_path = os.path.join(config.images_root_folder, f'{image_id}.jpg')
    return Image.open(image_path)