In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd


from data_parameters import DataParamMode, DataParamOptim
from data_parameters import DataParameterManager

# 1D Training

In [3]:
CLASSES = ['background','ventricle', 'myocard', 'vein']

class MiniDataset(Dataset):
    def __init__(self):
        self.len = 10
        self.images = -2*torch.ones((self.len, 35))
        self.labels = torch.stack(
            [
                torch.ones((self.len)), # background
                torch.zeros((self.len)), # venctricle
                torch.zeros((self.len)), # myocard
                torch.zeros((self.len)), # vein
            ],
            dim=-1
        ).long()
        # self.labels[0,0] = 1
        self.disturbed_idxs = []

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        if idx in self.disturbed_idxs:
            label = self.labels[idx]*10
        else:
            label = self.labels[idx]

        return {'d_idx': idx, 'image': self.images[idx], 'label': label}

    def set_disturbed_idxs(self, idxs):
        self.disturbed_idxs = idxs

        

class MiniNet(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.layer = torch.nn.AdaptiveAvgPool1d(1)
        self.linear = torch.nn.Linear(1,len(CLASSES), bias=False)

    def forward(self, _input):
        _output = self.layer(_input)
        _output = self.linear(_output)**2
        return _output

class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

config = dotdict({
    # Data parameter config
    'data_param_mode': DataParamMode.COMBINED_INSTANCE_CLASS_PARAMS,
    'init_class_param': 0.01, 
    'lr_class_param': 0.1,
    'init_inst_param': 1.0, 
    'lr_inst_param': 0.1,
    'wd_inst_param': 0.0,
    'wd_class_param': 0.0,
    
    'skip_clamp_data_param': False,
    'clamp_sigma_min': np.log(1/20),
    'clamp_sigma_max': np.log(20),
    'optim_algorithm': DataParamOptim.ADAM,
    'optim_options': dict(
        # momentum=.9
        # betas=(0.9, 0.999)
    )
})


# 1D only-inst parameters experiment

torch.manual_seed(0)
# random.seed(0)
np.random.seed(0)

net = MiniNet()
data = MiniDataset()

train_dataloader = DataLoader(data, 3, shuffle=True)
# Problem: SGD, w/o momentum, when disturbed parameters are in unbalanced batches (varying number of disturbed parameters in minibatch)
# independent of param group definition

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=.01)
data.set_disturbed_idxs([0,1,7]) #1,2,5,6,7,8,9
print(config.optim_algorithm)

epochs = 100

for mode in DataParamMode:
    print(mode)
    config.data_param_mode = mode
    dpm = DataParameterManager(instance_keys=range(len(data)), class_keys=CLASSES, config=config, device='cpu')

    for epx in range(epochs):
        # print("epoch", epx)
        for b_idx, sample in enumerate(train_dataloader):
            # if fail_var():
                # raise(ValueError(f"err at {epx} {b_idx}"))
            # print("batch", b_idx)
            image, label = sample['image'], sample['label']

            logits = net(image)
            loss = dpm.do_basic_train_step(
                criterion, 
                logits, 
                label, 
                optimizer, 
                inst_keys=sample['d_idx'].tolist(),
                scaler=None)
    print(f"loss={loss}")

    with torch.no_grad():
        df = pd.DataFrame(dpm.get_data_parameters_dict())
        for ridx, row in df.iterrows():
            for cidx, elem in row.iteritems():
                df[cidx][ridx] = np.exp(elem).item()

    pd.options.display.float_format = '{:.2f}'.format
    display(df)
    print()
    print()


DataParamOptim.ADAM
DataParamMode.ONLY_INSTANCE_PARAMS
Initialized instance data parameters with: 1.0
loss=0.00016682221030350775


Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,0.33,0.33,3.34,3.34,3.34,3.34,3.34,0.33,3.34,3.34




DataParamMode.ONLY_CLASS_PARAMS
Initialized class data parameters with: 0.01
loss=2.4635424613952637


Unnamed: 0,background,ventricle,myocard,vein
0,0.88,1.01,1.01,1.01




DataParamMode.COMBINED_INSTANCE_CLASS_PARAMS
Initialized combined data parameters with: 1.01
loss=5.69614485357306e-06


Unnamed: 0,0,1,2,3,4,5,6,7,8,9
background,0.46,0.46,4.62,4.65,4.62,4.63,4.62,0.46,4.62,4.63
ventricle,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75
myocard,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75
vein,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75,2.75




DataParamMode.SEPARATE_INSTANCE_CLASS_PARAMS
Initialized instance data parameters with: 1.0
Initialized class data parameters with: 0.01
loss=2.28912881539145e-06


Unnamed: 0,dp_inst:0,dp_inst:1,dp_inst:2,dp_inst:3,dp_inst:4,dp_inst:5,dp_inst:6,dp_inst:7,dp_inst:8,dp_inst:9,dp_class:background,dp_class:ventricle,dp_class:myocard,dp_class:vein
0,0.52,0.52,6.26,6.26,6.26,6.26,6.26,0.52,6.31,6.26,0.12,1.01,1.01,1.01




DataParamMode.DISABLED
loss=1.6860144138336182






# 2D Training

In [4]:
CLASSES = ['background','ventricle','myocard', 'aorta', 'splenic_vein']

class MiniDataset2D(Dataset):
    def __init__(self):
        self.len = 9
        SPATIAL_DIM = 4
        self.images = torch.zeros((self.len, 1, SPATIAL_DIM,SPATIAL_DIM))
        self.images[:self.len//3] = 1*torch.ones((self.len//3, 1, SPATIAL_DIM,SPATIAL_DIM))
        self.images[self.len//3:2*self.len//3] = 2*torch.ones((2*self.len//3-self.len//3, 1, SPATIAL_DIM,SPATIAL_DIM))
        self.images[2*self.len//3:] = 3*torch.ones((self.len-2*self.len//3, 1, SPATIAL_DIM,SPATIAL_DIM))

        self.labels = torch.zeros((self.len, SPATIAL_DIM,SPATIAL_DIM))
        self.labels[:self.len//3] = 1*torch.ones((self.len//3, SPATIAL_DIM,SPATIAL_DIM))
        self.labels[self.len//3:2*self.len//3] = 2*torch.ones((2*self.len//3-self.len//3, SPATIAL_DIM,SPATIAL_DIM))
        self.labels[2*self.len//3:] = 3*torch.ones((self.len-2*self.len//3, SPATIAL_DIM,SPATIAL_DIM))

        # lbls = 3*torch.ones((self.len, SPATIAL_DIM, SPATIAL_DIM)).long()
        # 13.78 1.11 6.11 0.37 14.47 11.88

        # Labels have Lx4x4x3 shape (onehot of 4x4 segmentation)
        self.labels = torch.nn.functional.one_hot(self.labels.long(), num_classes=len(CLASSES))
        self.disturbed_idxs = []

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        if idx in self.disturbed_idxs:
            label = self.labels[idx]
            # label[..., -1] = 1
            # label[..., 1] = 0
            # if idx == 0:
            #     label[..., -2] = 1
            label[idx] = 0
        else:
            label = self.labels[idx]

        return {'d_idx': idx, 'image': self.images[idx], 'label': label}

    def set_disturbed_idxs(self, idxs):
        self.disturbed_idxs = idxs



class MiniNet2D(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = torch.nn.Conv2d(1, len(CLASSES), (1,1), bias=False)

    def forward(self, _input):
        _output = self.conv(_input)**2
        return _output.permute(0,2,3,1)


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

config = dotdict({
    # Data parameter config
    'data_param_mode': DataParamMode.ONLY_INSTANCE_PARAMS,
    'init_class_param': 0.01, 
    'lr_class_param': 0.1,
    'init_inst_param': 1.0, 
    'lr_inst_param': 0.1,
    'wd_inst_param': 0.0,
    'wd_class_param': 0.0,
    
    'skip_clamp_data_param': False,
    'clamp_sigma_min': np.log(1/20),
    'clamp_sigma_max': np.log(20),
    'optim_algorithm': DataParamOptim.ADAM,
    'optim_options': dict(
        # momentum=.9
        # betas=(0.9, 0.999)
    )
})

# torch.manual_seed(0)
# # random.seed(0)
# np.random.seed(0)


net = MiniNet2D()
data2D = MiniDataset2D()

train_dataloader = DataLoader(data2D, 3, shuffle=True)
# Problem: SGD, w/o momentum, when disturbed parameters are in unbalanced batches (varying number of disturbed parameters in minibatch)
# independent of param group definition

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=.01)
data.set_disturbed_idxs([0,3,6]) #1,2,5,6,7,8,9
print(config.optim_algorithm)

epochs = 100

for mode in DataParamMode:
    print(mode)
    config.data_param_mode = mode
    dpm = DataParameterManager(instance_keys=range(len(data)), class_keys=CLASSES, config=config, device='cpu')

    for epx in range(epochs):
        # print("epoch", epx)
        for b_idx, sample in enumerate(train_dataloader):
            # if fail_var():
                # raise(ValueError(f"err at {epx} {b_idx}"))
            # print("batch", b_idx)
            image, label = sample['image'], sample['label']

            logits = net(image)
            loss = dpm.do_basic_train_step(
                criterion, 
                logits, 
                label, 
                optimizer, 
                inst_keys=sample['d_idx'].tolist(),
                scaler=None)
    print(loss)

    with torch.no_grad():
        df = pd.DataFrame(dpm.get_data_parameters_dict())
        for ridx, row in df.iterrows():
            for cidx, elem in row.iteritems():
                df[cidx][ridx] = np.array(elem).item()#np.exp(elem).item()

    pd.options.display.float_format = '{:.2f}'.format
    display(df)
    print()
    print()

DataParamOptim.ADAM
DataParamMode.ONLY_INSTANCE_PARAMS
Initialized instance data parameters with: 1.0
0.133575439453125


Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,1.39,1.39,1.4,2.37,2.37,2.37,2.0,2.0,2.0,1.0




DataParamMode.ONLY_CLASS_PARAMS
Initialized class data parameters with: 0.01
0.03409909829497337


Unnamed: 0,background,ventricle,myocard,aorta,splenic_vein
0,0.01,-3.0,-3.0,-0.01,0.01




DataParamMode.COMBINED_INSTANCE_CLASS_PARAMS
Initialized combined data parameters with: 1.01
0.002429551212117076


Unnamed: 0,0,1,2,3,4,5,6,7,8,9
background,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.01
ventricle,-3.0,-3.0,-3.0,1.01,1.01,1.01,1.01,1.01,1.01,1.01
myocard,1.01,1.01,1.01,-2.38,-2.38,-2.38,1.01,1.01,1.01,1.01
aorta,1.01,1.01,1.01,1.01,1.01,1.01,-0.38,-0.38,-0.38,1.01
splenic_vein,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.01,1.01




DataParamMode.SEPARATE_INSTANCE_CLASS_PARAMS
Initialized instance data parameters with: 1.0
Initialized class data parameters with: 0.01
0.032711710780858994


Unnamed: 0,dp_inst:0,dp_inst:1,dp_inst:2,dp_inst:3,dp_inst:4,dp_inst:5,dp_inst:6,dp_inst:7,dp_inst:8,dp_inst:9,dp_class:background,dp_class:ventricle,dp_class:myocard,dp_class:aorta,dp_class:splenic_vein
0,-3.0,-3.0,-3.0,-3.0,-3.0,-3.0,-0.47,-0.47,-0.47,1.0,0.01,-3.0,-3.0,-3.0,0.01




DataParamMode.DISABLED
0.08998820185661316




