In [1]:
import os
import pickle
import random
import subprocess

import numpy as np
import torch
from torch import nn
from utils.show import (
    show_boundary,
    show_prediction_retlstm,
    show_raw
)
from utils.misc import (
    get_loaders_retlstm,
    send_model,
    load_model,
    get_and_load_model,
    get_and_load_ckp,
    comp
)
from utils.validate import validate_retlstm
from ret_lstm.models import (
    PatchModel,
    PatchModel2d,
    PatchModelPosEnc,
    PatchModelPosEnc2d,
    get_ConvModulePatch,
    RetLSTM,
    RetinaConv,
    RetinaConv2d,
    Block2d
)

SEED = 0

# Seed every RNG the notebook may draw from, so model initialisation and data
# augmentation repeat under Restart & Run All.
random.seed(SEED)        # stdlib RNG; NOTE(review): assumed to drive the albumentations transforms -- confirm for the installed version
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and every CUDA device, so no separate torch.cuda.manual_seed is needed

# cuDNN is switched off entirely (presumably to avoid run-to-run
# nondeterminism in its kernels). `deterministic` only takes effect if cuDNN is
# re-enabled later, but is kept so that case stays deterministic too.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.deterministic = True

## Build models and send them off for training

### Create PatchModel

In [12]:
# Disabled experiment (entire cell commented out): shape probe for an encoder
# built from ret_lstm.models.Block2d stages instead of plain Conv2d layers.
# skip = 3
# conv = nn.Sequential(
#     Block2d(2, 8, (5, 15), skip, (1, 2)),
#     Block2d(8, 16, (5, 15), skip, (1, 2)),
#     Block2d(16, 32, (5, 15), skip, (1, 2)),
#     Block2d(32, 64, (5, 15), skip, (1, 2))
# )
# x = torch.randn(1, 2, 64, 496)
# inf = conv(x).swapaxes(1, 2).flatten(2).shape[-1]
# inf

1984

In [3]:
# Encoder for the patch model: four stages, each made of two "same"-padded
# convolutions followed by a (1, 2) max-pool that halves only the width.
# The ReLU after each of the first three pools is a no-op (max-pooling
# non-negative activations keeps them non-negative). It is kept so the
# Sequential indices -- and hence the state_dict keys of models saved from
# this layout -- stay unchanged. A strided (1, 2) conv was tried earlier in
# place of each max-pool.
r = nn.ReLU(True)
mp = nn.MaxPool2d((1, 2))
inc = 1
kernel = (5, 15)
# "Same" padding for odd kernel sizes, derived from BOTH kernel dimensions so
# that changing `kernel` keeps the spatial size intact.
padding = tuple((k - 1) // 2 for k in kernel)
conv = nn.Sequential(
    # stage 1: inc -> 8 channels
    nn.Conv2d(inc, 8, kernel, padding=padding),
    r,
    nn.Conv2d(8, 8, kernel, padding=padding),
    r,
    mp,
    r,
    # stage 2: 8 -> 16 channels
    nn.Conv2d(8, 16, kernel, padding=padding),
    r,
    nn.Conv2d(16, 16, kernel, padding=padding),
    r,
    mp,
    r,
    # stage 3: 16 -> 32 channels
    nn.Conv2d(16, 32, kernel, padding=padding),
    r,
    nn.Conv2d(32, 32, kernel, padding=padding),
    r,
    mp,
    r,
    # stage 4: 32 -> 64 channels (no ReLU after the final pool)
    nn.Conv2d(32, 64, kernel, padding=padding),
    r,
    nn.Conv2d(64, 64, kernel, padding=padding),
    r,
    mp,
)
# Shape probe. `conv` keeps the height and divides the width (496) by 16.
# swapaxes(1, 2).flatten(2) turns every row into a (channels * width) vector,
# so `inf` is the per-row feature count (64 * 31 here).
# - The height does not affect `inf`, so a single row is enough.
# - Zeros (unlike randn) do not advance the global RNG, so the probe does not
#   shift the initialisation of layers created afterwards.
# - no_grad avoids building an autograd graph for a throw-away pass.
with torch.no_grad():
    x = torch.zeros(1, inc, 1, 496)
    inf = conv(x).swapaxes(1, 2).flatten(2).shape[-1]
inf

1984

In [4]:
# Patch model with "cc" positional encoding, built on the `conv` encoder and
# the probed feature size `inf` from the cell above. Earlier runs also tried
# reusing the encoder of the saved "pm_k15-5_same" model, a plain
# PatchModel2d head, and the run name "pm_k15-5_mm_same".
fc = nn.Linear(in_features=inf, out_features=8)
patchmodel = PatchModelPosEnc2d(
    conv=conv, fc=fc,
    pos_enc="cc", cc_type="center", sum_coord="mul",
)
send_model(patchmodel, "pm_k15-5_centerm_same")
# send_model(patchmodel, "pm_k15-5_mm_same")

626320
model sent


### Create LSTM

In [13]:
# Load the "pm_k15-nb" patch model; its `conv` encoder and `fc.in_features`
# are reused to build the LSTM's patch model below.
# NOTE(review): other cells use get_and_load_model instead -- confirm that
# load_model restores the weights you expect here.
pm = load_model("pm_k15-nb")

In [14]:
n_hidden = 128  # also passed as the RetLSTM hidden size below

# Reuse the encoder of `pm`, but swap its head for a fresh linear layer that
# projects the encoder features to `n_hidden` outputs.
# (A PatchModelPosEnc variant with "cc" positional encoding was also tried.)
patchmodel = PatchModel2d(
    conv=pm.conv,
    fc=nn.Linear(in_features=pm.fc.in_features, out_features=n_hidden),
)

In [16]:
# RetLSTM on top of the patch model. Its hidden size matches the patch
# model's output size, and the final linear head has 8 outputs.
lstm = RetLSTM(
    patchmodel=patchmodel,
    hidden_size=n_hidden,
    forget_bias=10,  # NOTE(review): large forget-gate bias; presumably keeps the gate open early in training -- confirm in RetLSTM
    fc=nn.Linear(in_features=n_hidden, out_features=8),
)
# NOTE(review): the meaning of init="def" is defined by utils.misc.send_model.
send_model(lstm, "lstm_k15-nb_nh128", init="def")

287440
model sent


### Create RetinaConv

In [5]:
# RetinaConv feature extractor: four "same"-padded convolutions with a
# height-1 kernel. Each row is filtered independently along the width, and
# the spatial size is preserved.
r = nn.ReLU(True)
kernel = (1, 5)
# "Same" padding derived from both kernel dims: (0, 2) for the (1, 5) kernel,
# and still correct if the kernel height is changed.
padding = tuple((k - 1) // 2 for k in kernel)
conv = nn.Sequential(
    nn.Conv2d(1, 8, kernel, padding=padding),
    r,
    nn.Conv2d(8, 16, kernel, padding=padding),
    r,
    nn.Conv2d(16, 32, kernel, padding=padding),
    r,
    nn.Conv2d(32, 64, kernel, padding=padding),
    r,
)
# Shape probe. The spatial size is preserved, so swapaxes(1, 2).flatten(2)
# yields 64 channels * 64 columns features per row.
# Zeros keep the probe from advancing the global RNG, and no_grad skips the
# autograd graph.
with torch.no_grad():
    x = torch.zeros(1, 1, 32, 64)
    inf2 = conv(x).swapaxes(1, 2).flatten(2).shape[-1]
inf2

In [9]:
# Assemble RetinaConv2d from the "pm_k15-nb" patch model:
# - its encoder is wrapped with no head (Identity);
# - the 1-D `conv` from the cell above gets a new linear layer (inf2 -> 64) as fc1;
# - the patch model's original head is reused as fc2.
pm = get_and_load_model("pm_k15-nb")
inf1 = pm.fc.in_features  # NOTE(review): not used anywhere in this cell
patchmodel = PatchModel2d(conv=pm.conv, fc=nn.Identity())
fc = nn.Linear(in_features=inf2, out_features=64)
rc = RetinaConv2d(
    patchmodel=patchmodel,
    conv=conv,
    fc1=fc,
    fc2=pm.fc,
)
send_model(rc, "rc_k15-nb_5", init="def")

418288
model sent


## Get model to visualize

In [2]:
# Rebuild the named patch model and restore checkpoint number 100 (presumably
# an epoch) from its pickle directory.
# NOTE(review): the checkpoint prefix encodes pw64, while the loaders below use
# patch_width = 496 -- confirm this is intended.
name = "pm_k15-5_center_same"
pm = get_and_load_model(name)
get_and_load_ckp(pm, "pkl/" + name + "/lr5.e-04_bs4_pw64_", 100)

>> checkpoint loaded


In [3]:
import albumentations as A

In [4]:
# Training augmentation: albumentations horizontal flip, applied with the
# library's default probability.
train_transf = A.HorizontalFlip()

In [5]:
patch_width = 496
# Build the train / validation loaders plus `mean_std` (presumably the
# normalisation statistics).
# NOTE(review): the meaning of 32, the four boolean flags and the trailing 2
# is defined by utils.misc.get_loaders_retlstm -- pass them by keyword once
# confirmed.
train_dl, valid_dl, mean_std = get_loaders_retlstm(
    "../../corrected_ds",
    patch_width,
    32,
    train_transf,
    True,
    True,
    True,
    False,
    2,
)
# One validation batch, reused by the visualisation cells below.
(x, y, corner, mask) = next(iter(valid_dl))

In [68]:
# Evaluate `pm` on the validation loader with an L1 loss.
# NOTE(review): the two boolean flags, the all-ones length-8 tensor and the
# trailing 9 are forwarded to utils.validate.validate_retlstm; see that
# function for their meaning.
loss, dice, ce, mad = validate_retlstm(pm, valid_dl, nn.L1Loss(), True, True, torch.ones(8), mean_std, 9)

                                                                                          

In [69]:
# Average of the `ce` metric returned by validate_retlstm.
print(ce.mean())

2.254


In [9]:
# Build prediction / ground-truth visualisations for the batch fetched above.
# The 0 presumably selects the first sample of that batch -- confirm in
# utils.show.show_prediction_retlstm.
pred, target = show_prediction_retlstm(pm, 0, x, y, corner, mean_std)

In [10]:
# Display the prediction figure returned by show_prediction_retlstm.
pred.show()

In [11]:
# Display the ground-truth figure returned by show_prediction_retlstm.
target.show()

## Filter bank encoder

In [None]:
# Filter-bank encoder. Only the first stage uses a 2-D (5, 15) kernel; the
# remaining stages use (1, 15) kernels, so they filter along the width only.
# Each stage ends with a (1, 2) max-pool that halves the width.
# The module order matches the PatchModel encoder (23 modules, pools at
# indices 4/10/16/22). The ReLU after each of the first three pools is a no-op
# on the non-negative pooled activations, and no ReLU follows the last pool.
# `r`, `mp` and `inc` are defined here so the cell does not rely on state left
# over from another cell.
r = nn.ReLU(True)
mp = nn.MaxPool2d((1, 2))
inc = 1
kernel_2d = (5, 15)
kernel_1d = (1, 15)
pad_2d = tuple((k - 1) // 2 for k in kernel_2d)  # "same" padding: (2, 7)
pad_1d = tuple((k - 1) // 2 for k in kernel_1d)  # "same" padding: (0, 7)
conv = nn.Sequential(
    # stage 1 (2-D kernels): inc -> 8 channels
    nn.Conv2d(inc, 8, kernel_2d, padding=pad_2d),
    r,
    nn.Conv2d(8, 8, kernel_2d, padding=pad_2d),
    r,
    mp,
    r,
    # stage 2 (width-only kernels): 8 -> 16 channels
    nn.Conv2d(8, 16, kernel_1d, padding=pad_1d),
    r,
    nn.Conv2d(16, 16, kernel_1d, padding=pad_1d),
    r,
    mp,
    r,
    # stage 3: 16 -> 32 channels
    nn.Conv2d(16, 32, kernel_1d, padding=pad_1d),
    r,
    nn.Conv2d(32, 32, kernel_1d, padding=pad_1d),
    r,
    mp,
    r,
    # stage 4: 32 -> 64 channels
    nn.Conv2d(32, 64, kernel_1d, padding=pad_1d),
    r,
    nn.Conv2d(64, 64, kernel_1d, padding=pad_1d),
    r,
    mp,
)
# Shape probe: the per-row feature count is channels * (496 // 16).
# The height does not affect it, so one zero row suffices. Zeros do not
# advance the global RNG, and no_grad skips building an autograd graph.
with torch.no_grad():
    x = torch.zeros(1, inc, 1, 496)
    inf = conv(x).swapaxes(1, 2).flatten(2).shape[-1]
inf