In [1]:
import torch.utils.model_zoo as modelzoo
from matplotlib import pyplot as plt
from torchvision import transforms
import torch.nn.functional as F
from google.colab import drive
from torchvision import models
from zipfile import ZipFile
from pathlib import Path
from PIL import Image
from torch import nn

import typing as tp
import shutil
import random
import torch
import time
import json
import os

In [2]:
drive.mount("/content/drive/")

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [3]:
# @title <h5>Pathfinder</h5>

class GTEAPaths:
    def __init__(self, path_to_folder: str, shuffle: bool=True,
        seed: tp.Optional[int]=None):
        
        self.path_to_image = "/".join([path_to_folder, "image"])
        self.path_to_mask = "/".join([path_to_folder, "mask"])
        self.folder = Path(self.path_to_image)
        self.paths = self._load_paths()

        if seed is not None:
            random.seed(seed)

        if shuffle is True:
            random.shuffle(self.paths)

    def _load_paths(self) -> tp.List[tp.Tuple[str, str]]:
        paths = []
        for file in self.folder.glob("*"):
            filename = file.name
            if ".mat" in filename:
                continue
            image_path = str(file)
            mask_name = filename[:filename.find(".")] + ".png"
            mask_path = "/".join([self.path_to_mask, mask_name])
            paths.append((image_path, mask_path))
        return paths

    def __getitem__(self, index: tp.Union[int, slice]):
        items = []
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = len(self) if index.stop is None else index.stop
            step = 1 if index.step is None else index.step
            for i in range(start, stop, step):
                if i < len(self):
                    items.append(self(i))
        else:
            items.append(self(index))
        return items

    def __len__(self):
        return len(self.paths)

    def __call__(self, index: int):
        return self.paths[index]

class EgoHandPaths:
    def __init__(self, path_to_folder: str, shuffle: bool=True,
        seed: tp.Optional[int]=None):
        
        self.path_to_image = "/".join([path_to_folder, "image"])
        self.path_to_mask = "/".join([path_to_folder, "mask"])
        self.folder = Path(self.path_to_image)
        self.paths = self._load_paths()

        if seed is not None:
            random.seed(seed)

        if shuffle is True:
            random.shuffle(self.paths)

    def _load_paths(self) -> tp.List[tp.Tuple[str, str]]:
        paths = []
        for subfolder in self.folder.glob("*"):
            subname = subfolder.name
            for file in subfolder.glob("*"):
                filename = file.name
                if ".mat" in filename:
                    continue
                image_path = str(file)
                mask_name = filename[6:]
                mask_path = "/".join([self.path_to_mask, subname, mask_name])
                paths.append((image_path, mask_path))
        return paths

    def __getitem__(self, index: tp.Union[int, slice]):
        items = []
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = len(self) if index.stop is None else index.stop
            step = 1 if index.step is None else index.step
            for i in range(start, stop, step):
                if i < len(self):
                    items.append(self(i))
        else:
            items.append(self(index))
        return items

    def __len__(self):
        return len(self.paths)

    def __call__(self, index: int):
        return self.paths[index]

class Ego2HandPaths:
    def __init__(self, path_to_folder: str, path_to_background: str, 
        shuffle: bool=True, seed: tp.Optional[int]=None):
        self.bg_folder = Path(path_to_background)
        self.bg_paths = self._load_bg_paths()
        self.folder = Path(path_to_folder)
        self.paths = self._load_paths()

        if seed is not None:
            random.seed(seed)

        if shuffle is True:
            random.shuffle(self.paths)

    def _load_bg_paths(self) -> tp.List[str]:
        files = self.bg_folder.glob('*')
        bg_paths = []
        for file in files:
            bg_paths.append(str(file))
        return bg_paths

    def _get_random_background(self):
        return random.choice(self.bg_paths)

    def _load_paths(self) -> tp.List[tp.Tuple[str, str]]:
        paths = []
        for action in self.folder.glob('*'):
            for sequence in action.glob('*'):
                for files in sequence.glob('*'):
                    images = list(files.glob('*'))
                    if len(images[0].name) == 9:
                        image = str(images[0])
                        mask = str(images[1])
                    else:
                        image = str(images[1])
                        mask = str(images[0])
                    paths.append((image, mask, self._get_random_background()))
        return paths

    def __getitem__(self, index: tp.Union[int, slice]):
        items = []
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = len(self) if index.stop is None else index.stop
            step = 1 if index.step is None else index.step
            for i in range(start, stop, step):
                if i < len(self):
                    items.append(self(i))
        else:
            items.append(self(index))
        return items

    def __len__(self):
        return len(self.paths)

    def __call__(self, index: int):
        return self.paths[index]

class UnionPaths:
    def __init__(self, paths: tp.List[tp.Tuple], 
        shuffle: bool=True, seed: tp.Optional[int]=None):

        self.paths = paths

        if seed is not None:
            random.seed(seed)

        if shuffle is True:
            random.shuffle(self.paths)

    def __getitem__(self, index: tp.Union[int, slice]):
        items = []
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = len(self) if index.stop is None else index.stop
            step = 1 if index.step is None else index.step
            for i in range(start, stop, step):
                if i < len(self):
                    items.append(self(i))
        else:
            items.append(self(index))
        return items

    def __len__(self):
        return len(self.paths)

    def __call__(self, index: int):
        return self.paths[index]

class AugmentationPaths:
    def __init__(self, paths: tp.List[tp.Tuple], angles: tp.Tuple[int, int]=(-180, 180),
        seed: tp.Optional[int]=None, shuffle: bool=True, augmentate: bool=True):
        
        # [color, rotation, hflip, vflip]
        self.changes = []
        self.paths = []

        if seed is not None:
            random.seed(seed)

        for path in paths:
            self.paths.append(path)
            self.changes.append([0, 0, False, False])
            
            if augmentate is False:
                continue
                
            angle = random.randint(angles[0], angles[1])
            color = random.choice([1, 2])
            hflip = random.choice([True, False])
            vflip = random.choice([True, False])
            self.paths.append(path)
            self.changes.append([color, angle, hflip, vflip])

        if shuffle is True:
            random.shuffle(self.paths)

    def __getitem__(self, index: tp.Union[int, slice]):
        items = []
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = len(self) if index.stop is None else index.stop
            step = 1 if index.step is None else index.step
            for i in range(start, stop, step):
                if i < len(self):
                    items.append(self(i))
        else:
            items.append(self(index))
        return items

    def __len__(self):
        return len(self.paths)

    def __call__(self, index: int):
        return (self.paths[index], self.changes[index])



In [4]:
# @title <h5>Loader</h5>

class Loader:
    def __init__(self, data: tp.List[ tp.Union[tp.Tuple[str, str, str], tp.Tuple[str, str]]]):

        """
        Convert pathes to PIL Images
        """

        self.aug_params = []
        self.images = []
        self.masks = []
        for item in data:
            collection = item[0]
            if len(collection) == 2:
                image = Image.open(collection[0])
                mask = Image.open(collection[1])
                self.images.append(image)
                self.masks.append(mask)
            elif len(collection) == 3:
                image = Image.open(collection[0])
                mask = Image.open(collection[1])
                background = Image.open(collection[2])
                image = self.paste(image, background)
                self.images.append(image)
                self.masks.append(mask)
            self.aug_params.append(item[1])

    def paste(self, image, background):
        background = background.resize(image.size)        
        background.paste(image, (0,0), image)
        return background

    def __getitem__(self, index: tp.Union[int, slice]):
        items = []
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = len(self) if index.stop is None else index.stop
            step = 1 if index.step is None else index.step
            for i in range(start, stop, step):
                if i < len(self):
                    items.append(self(i))
        else:
            items.append(self(index))
        return items

    def __call__(self, index: int):
        return (self.aug_params[index], (self.images[index], self.masks[index]))

    def __len__(self):
        return len(self.images)


In [5]:
# @title <h5>Preprocessor</h5>

class AugPreprocessor:
    def __init__(self, resize: tp.Tuple[int, int]):
        self.resize = resize

        self.to_gray = transforms.Grayscale(3)
        self.jitter = transforms.ColorJitter(brightness=.5, hue=.3)
        
        self.x_preprocessor = transforms.Compose([
            transforms.Resize(self.resize),
            transforms.ToTensor()
        ])

        self.y_preprocessor = transforms.Compose([
            transforms.Resize(self.resize),
            transforms.ToTensor(),
        ])

    def preprocess(self, collection, aug_params, bitwise: bool=False):
        
        color = aug_params[0]
        angle = aug_params[1]
        hflip = aug_params[2]
        vflip = aug_params[3]

        image = self.x_preprocessor(collection[0])
        image = image.unsqueeze(0)
        mask = self.y_preprocessor(collection[1])
        mask = mask.unsqueeze(0).type(torch.BoolTensor)

        if color==1:
            image = self.to_gray(image)
        elif color==2:
            image = self.jitter(image)
        if angle != 0:
            image = transforms.functional.rotate(image, angle)
            mask = transforms.functional.rotate(mask, angle)
        if hflip:
            image = transforms.functional.hflip(image)
            mask = transforms.functional.hflip(mask)
        if vflip:
            image = transforms.functional.vflip(image)
            mask = transforms.functional.vflip(mask)

        if bitwise is True:
            reverse_mask = torch.bitwise_not(mask)
            mask = torch.cat([mask, reverse_mask], dim=1)
        return image, mask

    def __call__(self, data, bitwise: bool=False):
        images = []
        masks = []
        for item in data:
            aug_param, collection = item
            image, mask = self.preprocess(collection, aug_param, bitwise)
            images.append(image)
            masks.append(mask)
        x = torch.cat(images, dim=0)
        y = torch.cat(masks, dim=0)
        return x, y

class Preprocessor:
    def __init__(self, resize: tp.Tuple[int, int]):
        self.resize = resize
        
        self.x_preprocessor = transforms.Compose([
            transforms.Resize(self.resize),
            transforms.ToTensor()
        ])

        self.y_preprocessor = transforms.Compose([
            transforms.Resize(self.resize),
            transforms.ToTensor(),
        ])

    def preprocess(self, collection, bitwise: bool=False):
        image = self.x_preprocessor(collection[0])
        image = image.unsqueeze(0)
        mask = self.y_preprocessor(collection[1])
        mask = mask.unsqueeze(0).type(torch.BoolTensor)
        if bitwise is True:
            reverse_mask = torch.bitwise_not(mask)
            mask = torch.cat([mask, reverse_mask], dim=1)
        return image, mask

    def __call__(self, data, bitwise: bool=False):
        images = []
        masks = []
        for collection in data:
            image, mask = self.preprocess(collection, bitwise)
            images.append(image)
            masks.append(mask)
        x = torch.cat(images, dim=0)
        y = torch.cat(masks, dim=0)
        return x, y


In [6]:
# @title <h5>Utils</h5>

def create_history():
    history = {
        "mIOU": [],
        "Loss": [],
        "PixelAccuracy": []
    }
    return history

def create_meta():
    meta = {
        "zip_slice_index": 0,
        "epoch": 0,
        "batch": 0
    }
    return meta

def save_json(path, data):
    with open(path, 'w') as file:
        json.dump(data, file)

def load_json(path):
    with open(path, 'r') as file:
        data = json.load(file)
    return data

def is_loadable(path: tp.Optional[str]) -> bool:
    if path is None:
        return False
    if Path(path).exists() is False:
        return False
    return True

def backgrounds_was_loaded(path: str):
    folder = Path(path)
    if not folder.exists():
        return False
    if len(list(folder.glob("*"))) == 0:
        return False
    return True

def create_folder(path):
    folder = Path(path)
    folder.mkdir(parents=True, exist_ok=True)

def delete_folder(path_to_folder: str):
    folder = Path(path_to_folder)
    if folder.exists() is True:
        shutil.rmtree(folder)

def split_on_batches(paths, batch_size):
    n_batches = len(paths) // batch_size
    batches = []
    for i in range(n_batches):
        start = i * batch_size
        stop = start + batch_size
        batch = paths[start:stop]
        batches.append(batch)
    return batches

def extract_zip(path_to_zip :str,
    slice: tp.Optional[tp.Tuple[int, int]]=None) -> None:
    with ZipFile(path_to_zip, 'r') as zip:
        namelist = zip.namelist()
        if slice is None:
            start, stop = (0, len(namelist))
        else:
            start, stop = slice
        zip.extractall(members=namelist[start: stop])

def create_heatmap(image, mask, alpha: float=0.4):
    convert_to_pil = transforms.ToPILImage()
    image = convert_to_pil(image).convert("RGB")
    mask = convert_to_pil(mask).convert("RGB")
    heatmap = Image.blend(image, mask, alpha)
    return heatmap

def save_heatmap(path, heatmap):
    heatmap.save(path)

def create_heatmap_from_folder(path_to_input, path_to_output, 
    model, preprocessor, device, alpha: float=0.2):

    folder = Path(path_to_input)

    if not folder.exists():
        raise Exception("Folder does not exists")

    for i, image_path in enumerate(folder.glob("*")):
        with torch.no_grad():
            image = Image.open(image_path).convert('RGB')
            x = preprocessor.x_preprocessor(image)
            x = x.unsqueeze(0)
            x = x.to(device).float()
            pred = model(x)
            mask = convert_prediction_to_mask(pred[0]).float()
            heatmap = create_heatmap(x[0], mask, alpha)
            path = path_to_output + "/" + str(i) + ".jpg"
            save_heatmap(path, heatmap)


In [7]:
# @title <h5>Verbose</h5>

def get_statistics(subset: tp.Tuple[float]) -> tp.Tuple[float, float, float]:
    
    """
    Return median, mean, std for a subset.
    """

    median = torch.tensor(subset).median().item()
    mean = torch.tensor(subset).mean().item()
    std = torch.tensor(subset).std().item()
    return median, mean, std

def show_state(epoch_index: int, batch_index: int, 
    subset_loss: tp.List[float], subset_miou: tp.List[float], 
    subset_mpa: tp.List[float], delimiter: str="*") -> None:
    
    """
    Show value of metrics while training.
    MPA is Mean Pixel accuracy
    MIou is Mean Intersection over union
    """

    loss_median, loss_mean, loss_std = get_statistics(subset_loss)
    miou_median, miou_mean, miou_std = get_statistics(subset_miou)
    mpa_median, mpa_mean, mpa_std = get_statistics(subset_mpa)
    
    print(delimiter, delimiter, delimiter)
    print()
    print(f"Epoch       : {epoch_index}")
    print(f"Batch       : {batch_index}")
    print()
    print(f"LOSS Median : {loss_median}")
    print(f"LOSS Mean   : {loss_mean}")
    print(f"LOSS std    : {loss_std}")
    print()
    print(f"MIoU Median : {miou_median}")
    print(f"MIoU Mean   : {miou_mean}")
    print(f"MIoU Std    : {miou_std}")
    print()
    print(f"MPA Median : {mpa_median}")
    print(f"MPA Mean   : {mpa_mean}")
    print(f"MPA Std    : {mpa_std}")
    print()


In [8]:
# @title <h5>Metrics</h5>

def convert_prediction_to_mask(prediction, thr: float=0):
    mask = prediction > thr
    return mask

# INTERSECTION OVER UNION
def get_iou(prediction, target):

    if target.shape != prediction.shape:
        raise Exception('A target shape doesn`t match with a prediction shape')

    if target.dim() != 3:
        raise Exception(f'A target dim is {target.dim()}. Must be 3.')

    pred_copy = prediction.clone()
    pred_copy = convert_prediction_to_mask(pred_copy)
    
    target_copy = target.clone()
    target_copy = convert_prediction_to_mask(target_copy)

    intersection = torch.bitwise_and(target_copy, pred_copy).sum().item()
    union = torch.bitwise_or(target_copy, pred_copy).sum().item()
    
    if (target_copy.sum().item() == 0) and (pred_copy.sum().item() == 0):
        return 1
    elif union == 0:
        return 0

    return intersection / union

def get_mean_iou(predictions, targets):

    with torch.no_grad():
        if targets.shape != predictions.shape:
            raise Exception('A targets shape doesn`t match with a predictions shape')

        if targets.dim() != 4:
            raise Exception(f'A target dim is {targets.dim()}. Must be 4.')

        iou_sum = 0
        for i in range(targets.shape[0]):
            iou = get_iou(targets[i], predictions[i])
            iou_sum += iou
        mean_iou = iou_sum / targets.shape[0]
        return mean_iou

# PIXEL ACCURACY
def get_pixel_acc(prediction, target):

    if target.shape != prediction.shape:
        raise Exception('A target shape doesn`t match with a prediction shape')

    if target.dim() != 3:
        raise Exception(f'A target dim is {target.dim()}. Must be 3.')

    pred_copy = prediction.clone()
    pred_copy = convert_prediction_to_mask(pred_copy)

    target_copy = target.clone()
    target_copy = convert_prediction_to_mask(target_copy)

    same = (target_copy == pred_copy).sum().item()
    channels, height, width = target.shape
    area = height * width * channels
    acc = same / area
    return acc

def get_mean_pixel_acc(predictions, targets):

    with torch.no_grad():
        if targets.shape != predictions.shape:
            raise Exception('A targets shape doesn`t match with a predictions shape')

        if targets.dim() != 4:
            raise Exception(f'A target dim is {targets.dim()}. Must be 4.')

        acc_sum = 0
        for i in range(targets.shape[0]):
            acc = get_pixel_acc(targets[i], predictions[i])
            acc_sum += acc
        mean_acc = acc_sum / targets.shape[0]
        return mean_acc


In [9]:
# @title <h5>Checkpoint</h5>

def get_device():
    """
    Prefer to GPU
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

def optimizer_to(optimizer, device):
    for param in optimizer.state.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

def create_checkpoint(path, model, optimizer, loss_fn):
    to_save = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss_fn': loss_fn.state_dict(),
    }
    torch.save(to_save, path)

def save_model(path, model):
    to_save = {
        'model': model.state_dict(),
    }
    torch.save(to_save, path)

def load_checkpoint(path, device):
    return torch.load(path, map_location=device)


In [10]:
# @title <h5>Fit</h5>

def fit(batches: tp.List[tp.List], loader, preprocessor, model, optimizer, loss_fn, 
    epochs: int, ckpt_path: tp.Optional[str]=None, ckpt_per_iter: tp.Optional[int]=None,
    verbose_per_iter: tp.Optional[int]=None, history_path: tp.Optional[str]=None,
    meta_path: tp.Optional[str]=None, heatmap_input: tp.Optional[str]=None,
    heatmap_output: tp.Optional[str]=None, heatmap_per_iter: tp.Optional[int]=None):

    n_batches = len(batches)

    # GPU is preferred
    device = get_device()

    # Load model from checkpoint
    if is_loadable(ckpt_path):
        print("~ ~ ~ ~ ~ ~ ~ ~ ~ ~")
        print("Checkpoint was found")
        print("~ ~ ~ ~ ~ ~ ~ ~ ~ ~")
        print()
        checkpoint = load_checkpoint(ckpt_path, device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        loss_fn.load_state_dict(checkpoint['loss_fn'])

    # Load history
    if is_loadable(history_path) and history_path is not None:
        history = load_json(history_path)
    elif not is_loadable(history_path) and history_path is not None:
        history = create_history()

    # Load meta
    if is_loadable(meta_path) and meta_path is not None:
        meta = load_json(meta_path)
    elif not is_loadable(meta_path) and meta_path is not None:
        meta = create_meta()

    # Read parameters
    start_epoch = meta["epoch"]
    start_batch = meta["batch"]

    # Switch to device
    model.to(device)
    model.train()
    optimizer_to(optimizer, device)

    for epoch in range(start_epoch, epochs):

        subset_loss = []
        subset_miou = []
        subset_mpa = []

        for i in range(start_batch, len(batches)):

            batch = batches[i]

            # Data
            torch.cuda.empty_cache()
            loaded = loader(batch)
            x, y = preprocessor(loaded[:])
            x = x.to(device).float()
            y = y.to(device).float()

            # Step
            pred = model(x)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Metrics
            miou = get_mean_iou(pred, y)
            mpa = get_mean_pixel_acc(pred, y)

            # Add to buffer
            subset_loss.append(loss.item())
            subset_miou.append(miou)
            subset_mpa.append(mpa)

            # Add to history
            history['Loss'].append(loss.item())
            history['PixelAccuracy'].append(mpa)
            history['mIOU'].append(miou)

            # change meta batch
            meta["batch"] += 1

            # Verbose state of training
            if verbose_per_iter is not None:
                if (i+1) % verbose_per_iter == 0:
                    show_state(epoch, i+1, subset_loss, subset_miou, subset_mpa)
                    subset_loss = []
                    subset_miou = []
                    subset_mpa = []

            # Create checkpoint
            if (ckpt_path is not None) and (ckpt_per_iter is not None):
                if (i+1) % ckpt_per_iter == 0:
                    if meta_path is not None:
                        save_json(meta_path, meta)
                    if history_path is not None:
                        save_json(history_path, history)
                    create_checkpoint(ckpt_path, model, optimizer, loss_fn)
            
            if (heatmap_input is not None) and (heatmap_output is not None) and (heatmap_per_iter is not None):
                if (i+1) % heatmap_per_iter == 0:
                    if heatmap_output[-1] == "/":
                        current_output = f"{heatmap_output}{epoch}{i+1}"
                    else:
                        current_output = f"{heatmap_output}/{epoch}{i+1}"
                    create_folder(current_output)
                    create_heatmap_from_folder(heatmap_input, current_output,
                        model, preprocessor, device, alpha=0.4)
            
        # change meta epoch
        meta["epoch"] += 1
        meta["batch"] = 0
        start_batch = 0


In [11]:
# @title <h5>Unet</h5>

class ConvReLUBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
        padding=0, stride=1, dilation=1, bias=False, separable=False):
        super(ConvReLUBN, self).__init__()

        groups = in_channels if separable else 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 
            padding=padding, stride=stride, dilation=dilation, bias=bias, groups=groups)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.norm(x)
        return x

class Exploration(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Exploration, self).__init__()
        self.sequence = nn.Sequential(
            ConvReLUBN(in_channels, out_channels, kernel_size=3, padding=1),
            ConvReLUBN(out_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.sequence(x)
        return x

class Unet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features: tp.List[int]=[64, 128, 256, 512]):
        super(Unet, self).__init__()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for feature in features:
            self.downs.append(Exploration(in_channels, feature))
            in_channels = feature

        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(Exploration(feature*2, feature))

        self.bottleneck = Exploration(features[-1], features[-1]*2)
        self.output = nn.Conv2d(features[0], out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for index in range(0, len(self.ups), 2):
            x = self.ups[index](x)
            skip_connection = skip_connections[index//2]
            concat = torch.cat([x, skip_connection], dim=1)
            x = self.ups[index+1](concat)

        return self.output(x)


In [12]:
subset_0_4 = "/content/drive/MyDrive/Use_your_hands/Dataset/Ego2Hand_0_4.zip"
subset_5_10 = "/content/drive/MyDrive/Use_your_hands/Dataset/Ego2Hand_5_10.zip"
subset_11_16 = "/content/drive/MyDrive/Use_your_hands/Dataset/Ego2Hand_11_16.zip"
subset_17_21 = "/content/drive/MyDrive/Use_your_hands/Dataset/Ego2Hand_17_21.zip"
backgrounds = "/content/drive/MyDrive/Use_your_hands/Dataset/backgrounds.zip"

In [13]:
ckpt_path = "/content/drive/MyDrive/Unet/pretrained/UnetTranspose.pt"
meta_path = "/content/drive/MyDrive/Unet/meta/UnetTranspose.json"
history_path = "/content/drive/MyDrive/Unet/history/UnetTranspose.json"

In [14]:
heatmap_input = "/content/drive/MyDrive/Unet/heatmap/input"
heatmap_output = "/content/drive/MyDrive/Unet/heatmap/output/transpose"

In [22]:
extract_zip(backgrounds)

In [None]:
!unzip /content/drive/MyDrive/Dataset/HandOverFace.zip
!unzip /content/drive/MyDrive/Dataset/EgoHand.zip
!unzip /content/drive/MyDrive/Dataset/GTEA.zip

In [15]:
resize = (240, 320)
model = Unet()
loss_fn = nn.BCEWithLogitsLoss()
preprocessor = AugPreprocessor(resize)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)

In [None]:
if is_loadable(meta_path):
    meta = load_json(meta_path)
    start_epoch = meta["epoch"]
else:
    start_epoch = 0

for i in range(start_epoch, 4):

    if os.path.exists("Ego2Hand"):
        shutil.rmtree("Ego2Hand")
      
    torch.manual_seed(42)
    random.seed(42)
    
    extract_zip(subset_0_4, [8000*i, 8000*(i+1)])
    extract_zip(subset_5_10, [8000*i, 8000*(i+1)])
    extract_zip(subset_11_16, [8000*i, 8000*(i+1)])
    extract_zip(subset_17_21, [8000*i, 8000*(i+1)])
    
    gtea_paths = GTEAPaths("GTEA", seed=42)
    ego_hand_paths = EgoHandPaths("EgoHand", seed=42)
    ego_2_hand_paths = Ego2HandPaths("Ego2Hand", "backgrounds", seed=42)
    paths = AugmentationPaths([*gtea_paths[:], *ego_hand_paths[:], *ego_2_hand_paths[:]], seed=42)
    batches = split_on_batches(paths, 4)

    fit(batches, Loader, preprocessor, model, optimizer, loss_fn, (i+1), 
        ckpt_path, 25, 25, history_path, meta_path, heatmap_input,
        heatmap_output, 25)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
MPA Std    : 0.009471128694713116

* * *

Epoch       : 1
Batch       : 1300

LOSS Median : 0.05113939568400383
LOSS Mean   : 0.08826544135808945
LOSS std    : 0.0914829820394516

MIoU Median : 0.7191808223724365
MIoU Mean   : 0.7431752681732178
MIoU Std    : 0.12122315913438797

MPA Median : 0.9810123443603516
MPA Mean   : 0.9729200005531311
MPA Std    : 0.020992157980799675

* * *

Epoch       : 1
Batch       : 1325

LOSS Median : 0.06606404483318329
LOSS Mean   : 0.06852160394191742
LOSS std    : 0.02277274988591671

MIoU Median : 0.6448771357536316
MIoU Mean   : 0.6682133674621582
MIoU Std    : 0.12551912665367126

MPA Median : 0.97725909948349
MPA Mean   : 0.9735147953033447
MPA Std    : 0.014548328705132008

* * *

Epoch       : 1
Batch       : 1350

LOSS Median : 0.06380241364240646
LOSS Mean   : 0.06790410727262497
LOSS std    : 0.027253735810518265

MIoU Median : 0.7039154767990112
MIoU Mean   : 