In [2]:
# Generic/Built-in
from typing import Callable, Generator, List, TypeVar, Tuple

# Libs
import time
import math
from tqdm.notebook import tqdm
import dask
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint_sequential
import tntorch as tnt
import itertools
from tensorly.tenalg import tensor_dot

# Custom
from unet.unet import UNet
from autoencoder.cnnautoencoder_asym import AE
from preprocessing import LazyDataset, PatchedImage, PatchCoord
from imseg_wrapper import FedSeg

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [2]:
fedseg = FedSeg(
    unet_channels=3,
    unet_classes=3,
    image_height=8,
    image_width=4,
    patch_height=4,
    patch_width=4,
    tt_decomp_max_rank=3,
    tensor_batch_size=1000,
    AE=AE,
    AD=AE,
    unet=UNet
)

lazy_train = LazyDataset("data/stare/examples/stare/datasets/partition3/train")
lazy_test = LazyDataset("data/stare/examples/stare/datasets/partition3/evaluate")

lazy_train_dataloader = DataLoader(
    lazy_train, batch_size=4, collate_fn=lambda x: tuple(x)
)
lazy_test_dataloader = DataLoader(
    lazy_test, batch_size=4, collate_fn=lambda x: tuple(x)
)

In [3]:
# testing pytest

fedseg = FedSeg(
    unet_channels=3,
    unet_classes=3,
    image_height=9,
    image_width=9,
    patch_height=3,
    patch_width=3,
    AE=AE,
    AD=AE,
    unet=UNet,
)

fedseg.unet_input = 33

fedseg.create_adapter_layers()
tensor = torch.randint(0, 255, (1, 3, 3, 3)) / 255

encoder_logits = fedseg.autoencoder.forward(tensor)
print(encoder_logits.shape)
unet_logits = fedseg.unet(encoder_logits)
print(unet_logits.shape)
decoder_logits = fedseg.autodecoder(unet_logits)


Starting the encoder layers... Original input height is 3 and input width is 3
Finished adding max pool layers

For height, with padding: 0, dilation: 1, kernel_size[0]: 3, and stride:1, the output height is 1, and the difference is 2
For width, with padding: 0, dilation: 1, kernel_size[0]: 3, and stride:1, the output width is 1, and the the difference is 2
num convs: 0

Amt height to force shrink: 1

Continuing encoder layers... -1 image shrinking layers...
For height, with padding: 0, dilation: 1, kernel_size[0]: 2, and stride:1, the output height is 2, and the difference is 1
For width, with padding: 0, dilation: 1, kernel_size[0]: 2, and stride:1, the output width is 2, and the the difference is 1
Round 0: cur image height: 2, cur image width: 2

Starting the decoder layers... input height to decoder layers is 2 and input width is 2
Starting decoder ConvTranspose2D layers...Image height: 2, Input width;2
Round 0 - Count 1: Cur image sizes after ConvTranspose2D - cur image height: 4

In [33]:
from torchsummary import summary

autoencoder = AE(
    input_height=91,
    input_width=91,
    encoded_size=3,
    decoded_size=37,
    in_channels=3,
    out_channels=3,
    num_output_features=3,
    kernel_size=(5, 5),
)

summary(autoencoder, (3,91,91))

Starting the encoder layers... Original input height is 91 and input width is 91
For height, with padding: 0, dilation: 1, kernel_size[0]: 5, and stride:1, the output height is 87, and the difference is 4
For width, with padding: 0, dilation: 1, kernel_size[0]: 5, and stride:1, the output width is 87, and the the difference is 4
Round 0 - Count 1: cur image height: 87, cur image width: 87
For height, with padding: 0, dilation: 1, kernel_size[0]: 5, and stride:1, the output height is 83, and the difference is 4
For width, with padding: 0, dilation: 1, kernel_size[0]: 5, and stride:1, the output width is 83, and the the difference is 4
Round 0 - Count 2: cur image height: 83, cur image width: 83
Round 0 - Count 2: cur image height: 41, cur image width: 41
For height, with padding: 0, dilation: 1, kernel_size[0]: 5, and stride:1, the output height is 37, and the difference is 4
For width, with padding: 0, dilation: 1, kernel_size[0]: 5, and stride:1, the output width is 37, and the the di

OverflowError: integer division result too large for a float

In [5]:
# forward an image
image, mask = next(fedseg.images(lazy_train_dataloader))
image.load()
new_image = fedseg.forward(image)

Generating patches...

PATCH:torch.Size([4, 4, 3])
Flattening patch...
Calculating local maps...
Calculating global maps...
Converting global map to TT...
Padding tensors......
Forward pass...
Starting the encoder layers... Original input height is 3 and input width is 3
Finished adding max pool layers

For height, with padding: 0, dilation: 1, kernel_size[0]: 3, and stride:1, the output height is 1, and the difference is 2
For width, with padding: 0, dilation: 1, kernel_size[0]: 3, and stride:1, the output width is 1, and the the difference is 2
num convs: 0

Amt height to force shrink: 1

Continuing encoder layers... -1 image shrinking layers...
For height, with padding: 0, dilation: 1, kernel_size[0]: 2, and stride:1, the output height is 2, and the difference is 1
For width, with padding: 0, dilation: 1, kernel_size[0]: 2, and stride:1, the output width is 2, and the the difference is 1
Round 0: cur image height: 2, cur image width: 2

Starting the decoder layers... input height to



TENSOR COUNT:  10
TENSOR COUNT:  20
TENSOR COUNT:  30
TENSOR COUNT:  40
TENSOR COUNT:  50
stacked tensor shape:torch.Size([3, 4, 4])
forwarded tensor shape:torch.Size([4, 4, 3])
stacked_tensor.shape:  torch.Size([4, 4, 3])
patch counter: 1
PATCH:torch.Size([4, 4, 3])
Flattening patch...
Calculating local maps...
Calculating global maps...
Converting global map to TT...
Padding tensors......
Forward pass...
TENSOR COUNT:  10
TENSOR COUNT:  20
TENSOR COUNT:  30
TENSOR COUNT:  40
TENSOR COUNT:  50
stacked tensor shape:torch.Size([3, 4, 4])
forwarded tensor shape:torch.Size([4, 4, 3])
stacked_tensor.shape:  torch.Size([4, 4, 3])
patch counter: 2
number of patches 2
shape of final patch image torch.Size([8, 4, 3])


## Training Loop

In [5]:
def pixel_accuracy(output, mask):
    print(f"pixel input mask shape:{mask.shape}")
    print(f"pixel input output shape:{output.shape}")
    with torch.no_grad():
        print(f"softmax results: {F.softmax(output, dim=1).shape}")
        output = torch.argmax(F.softmax(output, dim=1), dim=0)
        print(f"pixel mask shape:{mask.shape}")
        print(f"pixel output shape:{output.shape}")
        correct = torch.eq(output, mask).int()
        accuracy = float(correct.sum()) / float(correct.numel())
    return accuracy

In [6]:
def mIoU(pred_mask, mask, smooth=1e-10, n_classes=2):
    with torch.no_grad():
        pred_mask = F.softmax(pred_mask, dim=1)
        pred_mask = torch.argmax(pred_mask, dim=1)
        pred_mask = pred_mask.contiguous().view(-1)
        mask = mask.contiguous().view(-1)

        iou_per_class = []
        for clas in range(0, n_classes):  # loop per pixel class
            true_class = pred_mask == clas
            true_label = mask == clas

            if true_label.long().sum().item() == 0:  # no exist label in this loop
                iou_per_class.append(np.nan)
            else:
                print(true_class.shape)
                print(true_label.shape)
                intersect = (
                    torch.logical_and(true_class, true_label).sum().float().item()
                )
                union = torch.logical_or(true_class, true_label).sum().float().item()

                iou = (intersect + smooth) / (union + smooth)
                iou_per_class.append(iou)
        return np.nanmean(iou_per_class)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def fit(
    epochs,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    patch=False,
):
    torch.cuda.empty_cache()
    train_losses = []
    test_losses = []
    val_iou = []
    val_acc = []
    train_iou = []
    train_acc = []
    # lrs = []
    min_loss = np.inf
    decrease = 1
    not_improve = 0

    model
    fit_time = time.time()
    for e in range(epochs):
        since = time.time()
        running_loss = 0
        iou_score = 0
        accuracy = 0
        # training loop
        model.train()
        for i, data in enumerate(tqdm(fedseg.images(train_loader))):
            # img is in the format[C,H,W]
            image, mask = data
            output = model(image)
            mask_array = mask.retrieve(PatchCoord(0, 0), PatchCoord(700, 605))
            chw_output = torch.permute(output, (2, 0, 1))
            # because the mask is of type torch.ByteTensor, we have to cast it to float
            mask_long = mask_array.long()

            loss = criterion(chw_output.to(device), mask_long.to(device))  # to fix this

            # evaluation metrics

            accuracy += pixel_accuracy(output, mask)
            print(f"accuracy: {accuracy}")
            # backward
            loss.backward()
            optimizer.step()  # update weight
            optimizer.zero_grad()  # reset gradient

            # step the learning rate
            scheduler.step()

            running_loss += loss.item()
            torch.save(
                {
                    "epoch": e,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": loss,
                },
                f"model_epoch{e}.pt",
            )
        else:
            print("Evaluating...")
            model.eval()
            test_loss = 0
            test_accuracy = 0
            val_iou_score = 0
            # validation loop
            with torch.no_grad():

                for i, data in enumerate(tqdm(fedseg.images(val_loader))):
                    # img is in the format[C,H,W]
                    image, mask = data
                    image.load()
                    output = model(image)
                    print("output image shape", image.shape)
                    mask_array = mask.retrieve(
                        PatchCoord(0, 0), PatchCoord(4, 8)
                    )  # W, H
                    chw_output = torch.permute(output, (2, 0, 1)) 
                    # because the mask is of type torch.ByteTensor, we have to cast it to float
                    mask_long = mask_array.long()

                    # evaluation metrics
                    # val_iou_score +=  mIoU(output, mask)
                    print(f"test accuracy output shape:{output.shape}")
                    print(f"test accuracy mask shape:{mask.shape}")
                    test_accuracy += pixel_accuracy(output, mask)
                    # loss
                    loss = criterion(output, mask)
                    test_loss += loss.item()

            # calculatio mean for each batch
            train_losses.append(running_loss / len(train_loader))
            test_losses.append(test_loss / len(val_loader))

            if min_loss > (test_loss / len(val_loader)):
                print(
                    "Loss Decreasing.. {:.3f} >> {:.3f} ".format(
                        min_loss, (test_loss / len(val_loader))
                    )
                )
                min_loss = test_loss / len(val_loader)
                decrease += 1
                if decrease % 5 == 0:
                    print("saving model...")
                    # torch.save(model, 'Unet-Mobilenet_v2_mIoU-{:.3f}.pt'.format(val_iou_score/len(val_loader)))

            if (test_loss / len(val_loader)) > min_loss:
                not_improve += 1
                min_loss = test_loss / len(val_loader)
                print(f"Loss Not Decrease for {not_improve} time")
                if not_improve == 7:
                    print("Loss not decrease for 7 times, Stop Training")
                    break

            train_acc.append(accuracy / len(train_loader))
            val_acc.append(test_accuracy / len(val_loader))
            print(
                "Epoch:{}/{}..".format(e + 1, epochs),
                "Train Loss: {:.3f}..".format(running_loss / len(train_loader)),
                "Val Loss: {:.3f}..".format(test_loss / len(val_loader)),
                "Train Acc:{:.3f}..".format(accuracy / len(train_loader)),
                "Val Acc:{:.3f}..".format(test_accuracy / len(val_loader)),
                "Time: {:.2f}m".format((time.time() - since) / 60),
            )

    history = {
        "train_loss": train_losses,
        "val_loss": test_losses,
        "train_acc": train_acc,
        "val_acc": val_acc,
    }
    print("Total time: {:.2f} m".format((time.time() - fit_time) / 60))
    return history

In [8]:
max_lr = 1e-3
epoch = 3
weight_decay = 1e-4
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.enabled = True
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(fedseg.parameters(), lr=max_lr, weight_decay=weight_decay)
sched = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr, epochs=epoch, steps_per_epoch=len(lazy_train_dataloader)
)

history = fit(
    epoch,
    fedseg,
    lazy_train_dataloader,
    lazy_test_dataloader,
    criterion,
    optimizer,
    sched,
)

0it [00:00, ?it/s]

Generating patches...

PATCH:torch.Size([4, 4, 3])
Flattening patch...
Calculating local maps...
Calculating global maps...
Converting global map to TT...
Padding tensors......
Forward pass...
Starting the encoder layers... Original input height is 3 and input width is 3
Finished adding max pool layers

For height, with padding: 0, dilation: 1, kernel_size[0]: 3, and stride:1, the output height is 1, and the difference is 2
For width, with padding: 0, dilation: 1, kernel_size[0]: 3, and stride:1, the output width is 1, and the the difference is 2
num convs: 0

Amt height to force shrink: 1

Continuing encoder layers... -1 image shrinking layers...
For height, with padding: 0, dilation: 1, kernel_size[0]: 2, and stride:1, the output height is 2, and the difference is 1
For width, with padding: 0, dilation: 1, kernel_size[0]: 2, and stride:1, the output width is 2, and the the difference is 1
Round 0: cur image height: 2, cur image width: 2

Starting the decoder layers... input height to



TENSOR COUNT:  10
TENSOR COUNT:  20
TENSOR COUNT:  30
TENSOR COUNT:  40
TENSOR COUNT:  50
stacked tensor shape:torch.Size([3, 4, 4])
forwarded tensor shape:torch.Size([4, 4, 3])
stacked_tensor.shape:  torch.Size([4, 4, 3])
patch counter: 1
PATCH:torch.Size([4, 4, 3])
Flattening patch...
Calculating local maps...
Calculating global maps...
Converting global map to TT...
Padding tensors......
Forward pass...
TENSOR COUNT:  10
TENSOR COUNT:  20
TENSOR COUNT:  30
TENSOR COUNT:  40
TENSOR COUNT:  50
stacked tensor shape:torch.Size([3, 4, 4])
forwarded tensor shape:torch.Size([4, 4, 3])
stacked_tensor.shape:  torch.Size([4, 4, 3])
patch counter: 2
number of patches 2
2
1
shape of final patch image torch.Size([8, 4, 3])


ValueError: Expected input batch_size (3) to match target batch_size (1).

TO DO: 

- to fix self.image_height = image_height
- to get mask from PatchedImage
- to test out py file 
- to do unit test