# Train Model

## Load base model

In [None]:
from fastai.vision.all import *

In [None]:
def label_func(x):
    """Return the path of the mask that belongs to the image in row ``x``.

    Images are expected at ``<root>/images/<name>-image.png``; the matching
    mask is looked up at ``<root>/masks/obs0_rep0/<name>-mask.png``.

    Only the ``images`` directory that directly contains the file and the
    file name itself are rewritten. Other path components are left untouched,
    even if they happen to contain "image"/"images".

    Args:
        x: A row (e.g. of the training DataFrame) whose ``'file'`` entry holds
            the image path.

    Returns:
        The mask path as a string.

    Raises:
        ValueError: If the image does not live in a directory named ``images``.
    """
    from pathlib import Path

    image_path = Path(str(x['file']))
    if image_path.parent.name != "images":
        raise ValueError(
            f"expected an image inside an 'images' directory, got {image_path}"
        )
    mask_dir = image_path.parent.parent / "masks" / "obs0_rep0"
    # Only the file name's "image" tag becomes "mask"; directories stay as they are.
    return str(mask_dir / image_path.name.replace("image", "mask"))

In [None]:
def acc_seg(input, target):
    """Pixel accuracy: share of pixels whose arg-max class equals the target.

    ``input`` carries per-class scores along dim 1. ``target`` carries class
    indices; a singleton dimension at position 1 (if present) is dropped
    before the comparison.
    """
    predicted = input.argmax(dim=1)
    labels = target.squeeze(1)
    hits = predicted == labels
    return hits.float().mean()

def multi_dice(input:Tensor, targs:Tensor, class_id=0, inverse=False):
    """Mean per-sample Dice coefficient for a single class.

    The prediction is the arg-max of ``input`` over dim 1. Pixels equal to
    ``class_id`` (in the prediction and in ``targs``) form the foreground.
    With ``inverse=True`` the foreground is every pixel *not* equal to
    ``class_id`` instead.

    Args:
        input: Per-class scores with the batch at dim 0 and classes at dim 1.
        targs: Class-index targets with the batch at dim 0. All remaining
            dims are flattened, so an extra singleton channel dim is accepted.
        class_id: Class index to score.
        inverse: Score the complement of ``class_id``.

    Returns:
        Scalar tensor: the Dice score averaged over the batch. A sample in
        which both prediction and target are empty counts as 1.
    """
    n = targs.shape[0]
    # reshape rather than view, so non-contiguous tensors (e.g. after a
    # transpose/permute) are accepted instead of raising
    preds = input.argmax(dim=1).reshape(n, -1)
    # binarise: 1 where the pixel is class_id, 0 elsewhere
    output = (preds == class_id).float()
    targs = (targs.reshape(n, -1) == class_id).float()
    if inverse:
        output = 1 - output
        targs = 1 - targs
    intersect = (output * targs).sum(dim=1)
    # |A| + |B|: the Dice denominator (not the set union)
    denominator = (output + targs).sum(dim=1)
    res = 2. * intersect / denominator
    # 0/0 -> NaN when the class is absent from both masks: treat as perfect agreement
    res[torch.isnan(res)] = 1
    return res.mean()

def diceComb(input:Tensor, targs:Tensor):
    """Dice over every non-background pixel (class 0 inverted, so classes 1 and 2 together)."""
    return multi_dice(input, targs, inverse=True, class_id=0)

def diceLV(input:Tensor, targs:Tensor):
    """Dice for class 1 ("left_ventricle" in the mask codes used for training)."""
    return multi_dice(input, targs, inverse=False, class_id=1)

def diceMY(input:Tensor, targs:Tensor):
    """Dice for class 2 ("myocardium" in the mask codes used for training)."""
    return multi_dice(input, targs, inverse=False, class_id=2)

In [None]:
# Load the exported base learner; cpu=False means it is not forced onto the CPU
# (the default device is used).
# NOTE: load_learner unpickles the file — only load model files from a trusted source.
trainedModel = load_learner("../inputs/models/kaggle-ukbb-base-fastai2.pkl", cpu=False)

## Prepare data loading

In [None]:
import glob  # explicit import instead of relying on the star import to provide the module

# All image PNGs, sorted so the row order is deterministic.
all_files = sorted(glob.glob("../inputs/cmr-cine-sscrofa/data/png/images/*.png"))
# Fail fast: an empty match (wrong working directory or missing data) would
# otherwise only show up later as an empty DataFrame / DataLoaders.
if not all_files:
    raise FileNotFoundError(
        "no images found matching ../inputs/cmr-cine-sscrofa/data/png/images/*.png"
    )

In [None]:
# One row per image file; further columns are derived from the path in the next cells.
df = pd.DataFrame({"file":all_files})
df

In [None]:
# Parse metadata from the file name, e.g. "A05_slice004_frame029-image.png":
#   id    = text before the first "_"                  -> "A05"
#   frame = characters 5..7 of the third "_"-separated field -> int("029") == 29
#   (assumes a three-digit frame number)
df = df.assign(
    id=lambda tbl: [path.rsplit("/", 1)[-1].split("_")[0] for path in tbl['file']],
    frame=lambda tbl: [
        int(path.rsplit("/", 1)[-1].split("_")[2][5:8]) for path in tbl['file']
    ],
)
df

In [None]:
# Split assignment table; expected to provide at least 'id' and 'set' columns
# ('set' values used below: "val", "test", presumably "train") — confirm against the TSV.
sets = pd.read_csv("../inputs/training/cmr-cine-sscrofa.sets.tsv", sep="\t")

In [None]:
# Inner join on all shared column names (presumably just 'id'):
# images whose id is missing from the sets table are dropped.
df = pd.merge(df, sets)

In [None]:
# ES/ED frame annotations per id, reshaped from wide (columns 'es', 'ed') to long:
# one row per (id, phase) with the annotated frame number in column 'frame'.
esed = pd.melt(
    pd.read_csv("../inputs/cmr-cine-sscrofa/data/metadata/obs0_rep0.tsv", sep="\t"),
    id_vars=['id'],
    value_vars=['es', 'ed'],
    var_name="phase",
    value_name="frame",
)
esed

In [None]:
# Inner join on the shared columns (presumably 'id' and 'frame'): only images whose
# frame is an annotated ES or ED frame are kept, each tagged with its 'phase'.
df = pd.merge(df, esed)

In [None]:
# Boolean column read by ColSplitter in the DataBlock: True for the validation split.
df = df.assign(is_valid=lambda x: x['set']=="val")

In [None]:
# Sanity check: number of images per split.
df.set.value_counts()

In [None]:
# Hold out the test split; only the remaining (train + val) images are used for fitting.
train_val = df[df.set != "test"]

In [None]:
# Should no longer contain "test".
train_val.set.value_counts()

In [None]:
# Training (False) vs. validation (True) image counts.
train_val.is_valid.value_counts()

In [None]:
# Segmentation DataBlock: image in, 3-class mask out.
# Items are DataFrame rows; the train/valid split is taken from the boolean 'is_valid' column.
heart = DataBlock(
    blocks=(
        ImageBlock,
        MaskBlock(codes=np.array(["background", "left_ventricle", "myocardium"])),
    ),
    get_x=ColReader("file"),
    get_y=label_func,
    splitter=ColSplitter(col="is_valid"),
    # per item: resize to 512 using fastai's 'crop' resize method
    item_tfms=Resize(512, method='crop'),
    # per batch: flips, rotation up to 90 degrees, lighting and zoom augmentation; output size 256
    batch_tfms=aug_transforms(
        do_flip=True, max_rotate=90, max_lighting=.4, max_zoom=1.2, size=256
    ),
)

In [None]:
import ctypes
import os

# Load and initialise the MAGMA library by hand before training.
# NOTE(review): presumably works around a MAGMA initialisation problem of the
# PyTorch build in this environment — confirm whether it is still needed.
# The library location is machine-specific; set MAGMA_PATH to override it
# instead of editing the notebook.
# Previously used location: /tank/home/ankenbrand/miniconda3/lib/libmagma.so
magma_path = os.environ.get(
    "MAGMA_PATH",
    '/tank/home/ankenbrand/miniconda3/pkgs/magma-2.5.4-h6103c52_2/lib/libmagma.so',
)
libmagma = ctypes.cdll.LoadLibrary(magma_path)
libmagma.magma_init()

In [None]:
# Train/validation DataLoaders (batch size 16) built from the non-test images.
dls = heart.dataloaders(train_val, bs=16)

In [None]:
# Visual check that images and masks line up after augmentation.
dls.show_batch()

## Make predictions with base model

In [None]:
# Example frame, resized to 256x256 for display only (predict() below reads the file itself).
img = Image.open("../inputs/cmr-cine-sscrofa/data/png/images/A05_slice004_frame029-image.png").resize((256,256))

In [None]:
# Prediction of the unmodified base model; only the first returned value (the
# decoded mask) is used below.
pred, bla, blub = trainedModel.predict("../inputs/cmr-cine-sscrofa/data/png/images/A05_slice004_frame029-image.png")

In [None]:
plt.imshow(pred)

In [None]:
# Overlay the predicted mask on the image.
# NOTE(review): img is forced to 256x256; if pred has a different resolution the
# two layers will not align — check pred.shape.
plt.imshow(img, cmap="bone")
plt.imshow(pred, alpha=.5)

In [None]:
# Point the loaded learner at the new cmr-cine-sscrofa DataLoaders for evaluation and retraining.
trainedModel.dls = dls

In [None]:
# Base-model results on the new validation data, before any retraining.
trainedModel.show_results()

## Retrain model

In [None]:
# Output directory: save()/export() and the CSV logger write relative to this path.
trainedModel.path = Path("../model")

In [None]:
# Log per-epoch metrics to CSV, appending across the repeated fit calls below.
trainedModel.add_cbs([CSVLogger(append=True)])

In [None]:
# Stage 1: freeze all but the last parameter group.
trainedModel.freeze()

In [None]:
trainedModel.lr_find()

In [None]:
# Stage 1 (frozen): ten separate one-cycle runs of 10 epochs at lr_max=1e-4
# (the schedule restarts each run), checkpointing "10-epochs" ... "100-epochs".
for epochs_done in range(10, 101, 10):
    trainedModel.fit_one_cycle(10, lr_max=1e-4)
    trainedModel.save(f"{epochs_done}-epochs")

In [None]:
# Results after the frozen stage.
trainedModel.show_results()

In [None]:
# Stage 2: make all parameter groups trainable.
trainedModel.unfreeze()

In [None]:
trainedModel.lr_find()

In [None]:
# Stage 2 (unfrozen): ten separate one-cycle runs of 10 epochs at lr_max=1e-5,
# checkpointing "10-epochs-unfrozen" ... "100-epochs-unfrozen".
for epochs_done in range(10, 101, 10):
    trainedModel.fit_one_cycle(10, lr_max=1e-5)
    trainedModel.save(f"{epochs_done}-epochs-unfrozen")

In [None]:
# Validation predictions with inputs included: a = inputs, b = predictions, c = targets.
a,b,c = trainedModel.get_preds(with_input = True)

In [None]:
# Show the 10th validation input as an image (channels moved last for matplotlib).
# NOTE(review): imshow clips float values outside [0, 1]; the next cell checks the range.
plt.imshow(a[9].permute(1,2,0))

In [None]:
a[9].permute(1,2,0).max()

In [None]:
# Large view of the retrained model's decoded mask for another frame.
fig, ax = plt.subplots(figsize=(12,12))
ax.imshow(trainedModel.predict("../inputs/cmr-cine-sscrofa/data/png/images/A05_slice005_frame010-image.png")[0])

In [None]:
# Remove the CSV logger before exporting (presumably so it is not serialized with
# the learner), then export the final model under trainedModel.path.
trainedModel.remove_cb(CSVLogger)
trainedModel.export("100-epochs-unfrozen.pkl")