In [1]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchmetrics
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import numpy as np
import time
import math
import lightning.pytorch as pl

import wandb
from evaluate import evaluate
from unet import UNet
from unet.unet_model_lightning import UNetLightning
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
from eevaluate import etest
from etrainer import train_model, config_data
from etrainerfunctions import rotate
import etrainer



In [2]:
# --- Device and transform setup shared by the cells below ---
# Drop cached CUDA allocator blocks, then run on the GPU when one is available.
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device=torch.device('cpu')
# Used as the third element of every [f, g, epsilon] entry in `efunctions`.
# NOTE(review): what epsilon means is decided by the etrainer code — confirm there.
epsilon=0

# Fixed-size nearest-exact resizes: 550x550, 450x450 and 512x512.
sizeup = torchvision.transforms.Resize((550,550), interpolation=TF.InterpolationMode.NEAREST_EXACT)
sizedown=torchvision.transforms.Resize((450,450), interpolation=TF.InterpolationMode.NEAREST_EXACT)
resize=torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.NEAREST_EXACT)
# Displacement field: a 3x3 grid that is zero except for 0.5 at the centre,
# given a leading axis -> (1, 3, 3), then bicubically resized to (1, 512, 512).
deformation = torch.tensor([[0,0,0],[0,.5,0],[0,0,0]]).to(device=device, dtype=torch.float32)
deformation=torch.unsqueeze(deformation, dim=0)
deformation= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformation)
# All-zero companion field; stacking on a new last axis gives shape (1, 512, 512, 2).
# NOTE(review): presumably the two channels are the x and y displacements that
# elastic_transform expects — confirm the channel order.
ydeformation =torch.full([1,512,512], 0, device=device, dtype=torch.float32)
deformation = torch.stack([deformation, ydeformation], dim=3)
# Applies the fixed field above; NEAREST and 0.0 are passed positionally after the displacement.
deform = lambda tensor : TF.elastic_transform(tensor, deformation, TF.InterpolationMode.NEAREST, 0.0)


# upscale: resize to 550 then back to 512; etransform: resize to 450 then back to 512.
upscale=torchvision.transforms.Compose([sizeup,resize])
etransform=torchvision.transforms.Compose([sizedown,resize])
# Combination of etrainer.pad24 and the 450x450 resize; the application order is
# whatever etrainer.compose defines (not visible in this notebook).
shrinkcrop=etrainer.compose(etrainer.pad24,sizedown)

def shift(x, shiftnum=1, axis=-1):
    """Translate ``x`` along ``axis`` by ``shiftnum`` positions, zero-filling the gap.

    A positive ``shiftnum`` drops the first ``shiftnum`` entries along ``axis``
    and appends that many zeros at the end (content moves toward index 0).
    A negative ``shiftnum`` drops the last ``-shiftnum`` entries and prepends
    zeros. ``shiftnum == 0`` leaves the values unchanged.

    Args:
        x: tensor to shift.
        shiftnum: signed integer number of positions to shift.
        axis: dimension to shift along (any index accepted by torch.transpose).

    Returns:
        A tensor with the same shape as ``x``. When ``abs(shiftnum)`` is at
        least the size of ``axis`` the result is all zeros.
    """
    # Work on the last dimension so slicing and padding are uniform, then swap back.
    x = torch.transpose(x, axis, -1)
    width = x.shape[-1]
    # Clamp so an oversized shift cannot make the output wider than the input.
    steps = min(abs(shiftnum), width)
    if steps == 0:
        padded = x
    elif shiftnum > 0:
        # F.pad takes (left, right) for the last dimension.
        padded = nn.functional.pad(x[..., steps:], (0, steps))
    else:
        padded = nn.functional.pad(x[..., :width - steps], (steps, 0))
    return torch.transpose(padded, axis, -1)
def randshift(x):
    """Apply ``shift`` to ``x`` with a random offset and a random axis.

    The offset is drawn uniformly from the integers -6..6 first, then the axis
    is drawn from {-2, -1}; both come from the global ``random`` generator.
    """
    offset = random.randint(-6, 6)
    return shift(x, offset, random.randint(-2, -1))
# `efunctions` is the list handed to the model/trainer as args['etransforms'].
# Each entry is a 3-item list [f, g, epsilon]; the commented-out lines are
# alternative transform sets kept from earlier experiments.
#This is for scaling
#efunctions=[[etransform, etransform, epsilon], [upscale,upscale,epsilon]] 
# Active choice: a pair of RandomRotation(10) transforms.
# NOTE(review): the two RandomRotation objects are separate instances, so each
# presumably samples its own angle per call — confirm how etrainer pairs them.
efunctions=[[torchvision.transforms.RandomRotation(10), torchvision.transforms.RandomRotation(10),epsilon]]
#efunctions=efunctions+[[lambda x : shift(x, shiftnum, axis), lambda x : shift(x, shiftnum, axis), epsilon] for shiftnum in range(-1,1,2) for axis in range(-1,1,2)]+[[etransform, etransform, epsilon], [upscale,upscale,epsilon]]
#efunctions += [[randshift, randshift, epsilon]]
#efunctions = efunctions+[[deform,deform,epsilon]]
#efunctions = efunctions + [[rotate(90), rotate(90),0]]
#efunctions = efunctions + [[torchvision.transforms.ElasticTransform(interpolation=TF.InterpolationMode.NEAREST), torchvision.transforms.ElasticTransform(interpolation=TF.InterpolationMode.NEAREST), epsilon]]
# The triple-quoted block below is a bare string expression, so the sweep inside
# it never runs. If it is ever re-enabled, note that:
#   * range(-1,1) only yields -1 and 0;
#   * it reassigns the notebook-level `deformation` and `deform`;
#   * every appended lambda reads `deformation` when called, so all of them
#     would use the last field built by the loop, not the one from their iteration.
"""
for x in range(-1,1):
    for y in range(-1,1):
        for i in range(3):
            for j in range(3):
                for a in range(3):
                    for b in range(3):
                        deformationx = torch.tensor([[x if i==k else 0 for k in range(3)] if j==l else [0,0,0] for l in range(3)]).to(device=device, dtype=torch.float32)
                        deformationy = torch.tensor([[y if a==k else 0 for k in range(3)] if b==l else [0,0,0] for l in range(3)]).to(device=device, dtype=torch.float32)
                        deformationx=torch.unsqueeze(deformationx, dim=0)
                        deformationy=torch.unsqueeze(deformationy, dim=0)
                        deformationx= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformationx)
                        deformationy= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformationy)
                        deformation = torch.stack([deformationx, deformationy], dim=3)
                        deform = lambda tensor : TF.elastic_transform(tensor, deformation, TF.InterpolationMode.NEAREST, 0.0)
                        efunctions.append([deform,deform,0])
"""
# Experiment configuration. The whole dict is unpacked with ** into UNet,
# config_data and train_model in the cells of this notebook, so each key is a
# keyword argument of those callables. Keys containing a space ('class weights',
# 'test augmented', 'test augment') can only be received through a **kwargs
# catch-all on the receiving side.
# NOTE(review): the exact meaning of each option is defined in etrainer / unet,
# which are not visible here.
args = {'epochs' : 50,
        'batch_size' : 8,
        'amp' : True,
        'bilinear' : False,
        'classes' : 3,
        'learning_rate' : 1e-6,
        # Falsy -> no checkpoint is loaded; otherwise a path passed to torch.load.
        'load': False,
        #'load' : "C:\\Users\\jjkjj\\Equivariant\\EquivariantUNet\\bumbling-sponge-27_checkpoints\\checkpoint_epoch121.pth",
        'class weights' : [1,1,3],
        'epochbreaks' : False,
        'break_length' : 5,
        # Transform triples built in this cell (see `efunctions`).
        'etransforms' : efunctions,
        'equivariance_measure' : 'l1',
        # These two flags are compared right after the dict; a mismatch only prints a warning.
        'equivariant' : False,
        'eqerror' : False,
        'augmented' : 'rangle',
        'Linf' : False,
        'eqweight' : .1,
        'n' : 1,
        'debugging' : False,
        'in_channels' : 3,
        'wandb_project' : 'Equivariant UNet',
        'test_on_epoch_end' : True,
        'test augmented' : False,
        'test augment' : 'model transforms',
        'save_checkpoint' : False,
        'eqweight_scheduler' : False,
        'eqweight_decay' : 1.1,
        'lr_scheduler' : 'cyclic',
        'min_lr' : 1e-9,
        'max_lr' : 1.5e-4,
        'product_loss' : False,
        'C1norm' : False,
        'C1weight' : 1,
        # Dataset selectors; also read explicitly by the evaluation cell.
        'Oxford' : True,
        'HeLa' : False,
       }
# Sanity warning only: execution continues when the two flags disagree.
if args['equivariant'] != args['eqerror']:
    print('Equivariant and eqerror are different are you sure?')

# Build the network in float32 on the selected device, then switch it to
# channels-last memory format.
model = UNet(args['in_channels'], args['classes'], **args).to(device=device, dtype = torch.float32)
model = model.to(memory_format=torch.channels_last)
# Rotation angle used by the evaluation cell (passed to rotate(...) and to etest).
angle = 5

#model = UNetLightning(args['in_channels'], args['classes'], **args).to(device=device)


# Optional warm start: only runs when args['load'] is a truthy checkpoint path.
# NOTE(review): torch.load unpickles the file, so only point this at trusted checkpoints.
if args['load']:
    state_dict = torch.load(args['load'], map_location=device)
    #del state_dict['mask_values']
    model.load_state_dict(state_dict)
    logging.info(f'Model loaded from {args["load"]}')
# Echo two attributes of the constructed model (the recorded output is "1" and "False").
print(model.n)
print(model.Linf)
#print([param for param in list(model.parameters())])


1
False


In [7]:
# Evaluate `model` on the test split built by config_data, using a fixed
# rotation by `angle` (fill value 1) both as the augmentation transform and as
# the equivariance transform handed to etest.
# Depends on `model`, `args`, `device` and `angle` from the configuration cell.
# Gotcha: rotate(angle, fill=1) needs the `rotate` imported from etrainerfunctions;
# the notebook later defines a local rotate(angle) without a `fill` parameter,
# and running that cell first makes this call fail.
etest(model, config_data(HeLa=args['HeLa'], Oxford=args['Oxford'], split='test',
                         augmented=args['test augmented'], aug_transforms=[[rotate(angle, fill=1),rotate(angle, fill=1),0,1]]), device, True,
       epoch=1, experiment_started=False,
                         etransforms=[[rotate(angle, fill = 1),rotate(angle, fill=1),0,1]],
                         test_augment = args['test augment'], angle = angle, class_weights = None, wandb_project = 'Equivariant UNet')

                                                                                                                       

KeyboardInterrupt: 

In [None]:
# Build the training loader from the configuration dict and start training.
# The same `args` dict is unpacked into both calls, so config_data and
# train_model each receive every key as a keyword argument.
# Depends on `model`, `device` and `args` from the configuration cell.
train_loader = config_data(**args)
train_model(model, device, train_loader, **args)

[34m[1mwandb[0m: Currently logged in as: [33mjjkjjk23[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/50: 100%|█| 3480/3480 [03:21<00:00, 17.31img/s, Cumulative Dice=[0.7646768675453361, 0.8417441009104937, 0.3478
Epoch 2/50: 100%|█| 3480/3480 [03:42<00:00, 15.67img/s, Cumulative Dice=[0.8419741481200032, 0.8944961483451142, 0.5281
Epoch 3/50: 100%|█| 3480/3480 [03:54<00:00, 14.85img/s, Cumulative Dice=[0.8309374424232834, 0.8887414850037674, 0.5198
Epoch 4/50: 100%|█| 3480/3480 [03:50<00:00, 15.09img/s, Cumulative Dice=[0.8445855516126786, 0.9045857513087919, 0.5459
Epoch 5/50: 100%|█| 3480/3480 [03:49<00:00, 15.15img/s, Cumulative Dice=[0.8826471347918455, 0.92305296179892, 0.608765
Epoch 6/50: 100%|█| 3480/3480 [03:49<00:00, 15.14img/s, Cumulative Dice=[0.8461500255540869, 0.9001461151002468, 0.5523
Epoch 7/50: 100%|█| 3480/3480 [03:50<00:00, 15.07img/s, Cumulative Dice=[0.8965620878099025, 0.9318929629764338, 0.6326
Epoch 8/50: 100%|█| 3480/3480 [03:50<00:00, 15.09img/s, Cumulative Dice=[0.8651940781494667, 0.9108806690950503, 0.5751
Epoch 9/50: 100%|█| 3480/3480 [03:51<00:

In [None]:
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is documented as a comma-separated list
# of `option:value` pairs, and "caching_allocator" does not have that form.
# Setting it here may also be too late if CUDA was already initialised by an
# earlier cell — confirm what was intended.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "caching_allocator"
# Force the HeLa dataset for this run. The overrides are merged into a copy of
# `args` instead of being passed next to **args: an `args` dict that already
# contains 'HeLa' / 'Oxford' would otherwise raise
# "TypeError: got multiple values for keyword argument".
train_loader = config_data(**{**args, 'HeLa': True, 'Oxford': False})
# NOTE(review): the configuration cells build `model` with UNet and leave the
# UNetLightning construction commented out; Trainer.fit expects a
# LightningModule — confirm which model this cell is meant to train.
trainer = pl.Trainer(max_epochs=1, accelerator='gpu')
trainer.fit(model, train_dataloaders=train_loader)

In [3]:
import PIL
import torch
import torchvision

class HeLaDataset(torch.utils.data.Dataset):
    """Map-style dataset over two multi-frame image files (images and labels).

    Item ``idx`` is frame ``idx`` of the image file paired with frame ``idx``
    of the label file, each converted to a tensor via numpy.

    Args:
        train_file: path of the image stack, opened with PIL.
        label_file: path of the label stack, opened with PIL.
        n_images: number of frames to expose; this is what ``len()`` reports.
        transform: optional callable applied to the image tensor.
        target_transform: optional callable applied to the label tensor.
    """

    def __init__(self, train_file,label_file, n_images, transform=None, target_transform=None):
        super().__init__()
        # Both files stay open for the lifetime of the dataset; frames are
        # selected later with seek().
        self.images=PIL.Image.open(train_file)
        self.images.load()
        self.labels=PIL.Image.open(label_file)
        self.labels.load()
        self.n_images=n_images
        self.transform=transform
        self.target_transform=target_transform

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors for frame ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``0 .. n_images - 1``. Plain
                iteration (``for item in dataset``) relies on this to stop.
        """
        if not 0 <= idx < self.n_images:
            raise IndexError(f"index {idx} out of range for {self.n_images} images")
        # Move BOTH stacks to the requested frame; seeking only the images
        # would pair every image with the first label frame.
        self.images.seek(idx)
        self.labels.seek(idx)
        image = torch.from_numpy(np.array(self.images))
        label = torch.from_numpy(np.array(self.labels))
        # Each transform is optional and applied independently of the other.
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        """Number of frames exposed by the dataset (``n_images``)."""
        return self.n_images


In [2]:
# Smoke test: open the ISBI-2012 training volume and labels (30 frames) and print each item.
# The cell defining HeLaDataset must run first; the recorded output of this
# cell is a NameError from running it before that definition.
# NOTE(review): the paths are absolute and machine-specific, and the `f` prefix
# on the first one has no placeholders.
dataset= HeLaDataset(f"C:\\Users\\jjkjj\\Equivariant\\ISBI-2012-challenge\\train-volume.tif","C:\\Users\\jjkjj\\Equivariant\\ISBI-2012-challenge\\train-labels.tif", 30)
# Prints every (image, label) pair in full; `j` is not used.
for j, image in enumerate(dataset):
    print(image)

    

NameError: name 'HeLaDataset' is not defined

In [3]:
# Wall-clock timestamp taken when this cell starts.
# NOTE(review): `time0` is not read anywhere in the visible notebook.
time0=time.time()
#@torch.inference_mode()
def timemodel():
    """Display every image/label pair produced by the HeLa loader.

    Iterates the loader returned by config_data(HeLa=True, Oxford=False,
    augmented='randcombo'), drops the leading dimension of each tensor,
    converts both to PIL images and opens them with ``show()``.
    Despite the name, nothing is timed here. Returns None.
    """
    #model.eval()
    loader = config_data(HeLa=True, Oxford=False, augmented='randcombo')
    for image, label in loader:
        to_pil = torchvision.transforms.ToPILImage()
        # Convert both tensors before showing anything, image first.
        pictures = [to_pil(torch.squeeze(tensor, dim=0)) for tensor in (image, label)]
        for picture in pictures:
            picture.show()
    return None
"""
        x, y = batch
        x=x.to(device=self.device)
        y=y.to(device=self.device)
        y_hat = self(x).to(device=self.device)
        #randomval = torch.from_numpy(np.random.default_rng().random(size=(1,1,512,512))).to(torch.float32)
        randomval=etrainer.sampler((1,512,512),n=1, cuda=False)
        loss = F.binary_cross_entropy_with_logits(y_hat, y.to(torch.float32))
        if self.equivariant:
            for f in self.etransforms:
                loss+= torch.mean(torch.abs(self(f[1](randomval))-f[0](self(randomval))))
        return loss
"""
# Run the viewer loop defined above. The triple-quoted block just before this
# call is a bare string expression and is never executed.
timemodel()


KeyboardInterrupt: 

In [6]:
# --- Device and transform setup (repeats the setup of the first configuration cell) ---
# Drop cached CUDA allocator blocks, then run on the GPU when one is available.
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device=torch.device('cpu')
# Used as the third element of every [f, g, epsilon] entry in `efunctions`.
# NOTE(review): what epsilon means is decided by the etrainer code — confirm there.
epsilon=0

# Fixed-size nearest-exact resizes: 550x550, 450x450 and 512x512.
sizeup = torchvision.transforms.Resize((550,550), interpolation=TF.InterpolationMode.NEAREST_EXACT)
sizedown=torchvision.transforms.Resize((450,450), interpolation=TF.InterpolationMode.NEAREST_EXACT)
resize=torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.NEAREST_EXACT)
# Displacement field: a 3x3 grid that is zero except for 0.5 at the centre,
# given a leading axis -> (1, 3, 3), then bicubically resized to (1, 512, 512).
deformation = torch.tensor([[0,0,0],[0,.5,0],[0,0,0]]).to(device=device, dtype=torch.float32)
deformation=torch.unsqueeze(deformation, dim=0)
deformation= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformation)
# All-zero companion field; stacking on a new last axis gives shape (1, 512, 512, 2).
ydeformation =torch.full([1,512,512], 0, device=device, dtype=torch.float32)
deformation = torch.stack([deformation, ydeformation], dim=3)
# Applies the fixed field above; NEAREST and 0.0 are passed positionally after the displacement.
deform = lambda tensor : TF.elastic_transform(tensor, deformation, TF.InterpolationMode.NEAREST, 0.0)

def rotate(angle, fill=None):
    """Return a callable that rotates its input by ``angle`` degrees.

    This definition shadows the ``rotate`` imported from etrainerfunctions for
    the rest of the session. The evaluation cell calls ``rotate(angle, fill=1)``,
    so the ``fill`` keyword is accepted here as well.

    Args:
        angle: rotation angle forwarded to torchvision's functional ``rotate``.
        fill: fill value for the area outside the rotated image. The default
            ``None`` is torchvision's own default, so ``rotate(angle)`` behaves
            as before.

    Returns:
        A function ``inputs -> rotated inputs``.
    """
    return lambda inputs : torchvision.transforms.functional.rotate(inputs, angle, fill=fill)


# upscale: resize to 550 then back to 512; etransform: resize to 450 then back to 512.
upscale=torchvision.transforms.Compose([sizeup,resize])
etransform=torchvision.transforms.Compose([sizedown,resize])
# Combination of etrainer.pad24 and the 450x450 resize; the application order is
# whatever etrainer.compose defines (not visible in this notebook).
shrinkcrop=etrainer.compose(etrainer.pad24,sizedown)

def shift(x, shiftnum=1, axis=-1):
    """Translate ``x`` along ``axis`` by ``shiftnum`` positions, zero-filling the gap.

    A positive ``shiftnum`` drops the first ``shiftnum`` entries along ``axis``
    and appends that many zeros at the end (content moves toward index 0).
    A negative ``shiftnum`` drops the last ``-shiftnum`` entries and prepends
    zeros. ``shiftnum == 0`` leaves the values unchanged.

    Args:
        x: tensor to shift.
        shiftnum: signed integer number of positions to shift.
        axis: dimension to shift along (any index accepted by torch.transpose).

    Returns:
        A tensor with the same shape as ``x``. When ``abs(shiftnum)`` is at
        least the size of ``axis`` the result is all zeros.
    """
    # Work on the last dimension so slicing and padding are uniform, then swap back.
    x = torch.transpose(x, axis, -1)
    width = x.shape[-1]
    # Clamp so an oversized shift cannot make the output wider than the input.
    steps = min(abs(shiftnum), width)
    if steps == 0:
        padded = x
    elif shiftnum > 0:
        # F.pad takes (left, right) for the last dimension.
        padded = nn.functional.pad(x[..., steps:], (0, steps))
    else:
        padded = nn.functional.pad(x[..., :width - steps], (steps, 0))
    return torch.transpose(padded, axis, -1)
def randshift(x):
    """Apply ``shift`` to ``x`` with a random offset and a random axis.

    The offset is drawn uniformly from the integers -6..6 first, then the axis
    is drawn from {-2, -1}; both come from the global ``random`` generator.
    """
    offset = random.randint(-6, 6)
    return shift(x, offset, random.randint(-2, -1))
# `efunctions` is the list stored in args['etransforms'] below.
# Each entry is a 3-item list [f, g, epsilon]; the commented-out lines are
# alternative transform sets kept from earlier experiments.
#This is for scaling
#efunctions=[[etransform, etransform, epsilon], [upscale,upscale,epsilon]] 
# Active choice: a pair of RandomRotation(10) transforms.
# NOTE(review): the two RandomRotation objects are separate instances, so each
# presumably samples its own angle per call — confirm how etrainer pairs them.
efunctions=[[torchvision.transforms.RandomRotation(10), torchvision.transforms.RandomRotation(10),epsilon]]
#efunctions=efunctions+[[lambda x : shift(x, shiftnum, axis), lambda x : shift(x, shiftnum, axis), epsilon] for shiftnum in range(-1,1,2) for axis in range(-1,1,2)]+[[etransform, etransform, epsilon], [upscale,upscale,epsilon]]
#efunctions += [[randshift, randshift, epsilon]]
#efunctions = efunctions+[[deform,deform,epsilon]]
#efunctions = efunctions + [[rotate(90), rotate(90),0]]
#efunctions = efunctions + [[torchvision.transforms.ElasticTransform(interpolation=TF.InterpolationMode.NEAREST), torchvision.transforms.ElasticTransform(interpolation=TF.InterpolationMode.NEAREST), epsilon]]
# The triple-quoted block below is a bare string expression, so the sweep inside
# it never runs. If re-enabled, every appended lambda would read `deformation`
# at call time and therefore use the last field built by the loop.
"""
for x in range(-1,1):
    for y in range(-1,1):
        for i in range(3):
            for j in range(3):
                for a in range(3):
                    for b in range(3):
                        deformationx = torch.tensor([[x if i==k else 0 for k in range(3)] if j==l else [0,0,0] for l in range(3)]).to(device=device, dtype=torch.float32)
                        deformationy = torch.tensor([[y if a==k else 0 for k in range(3)] if b==l else [0,0,0] for l in range(3)]).to(device=device, dtype=torch.float32)
                        deformationx=torch.unsqueeze(deformationx, dim=0)
                        deformationy=torch.unsqueeze(deformationy, dim=0)
                        deformationx= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformationx)
                        deformationy= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformationy)
                        deformation = torch.stack([deformationx, deformationy], dim=3)
                        deform = lambda tensor : TF.elastic_transform(tensor, deformation, TF.InterpolationMode.NEAREST, 0.0)
                        efunctions.append([deform,deform,0])
"""
# Experiment configuration for the single-channel, single-class run logged
# under the 'HeLa EUNet' wandb project. The dict is unpacked with ** into UNet
# below, so every key is a keyword argument of that constructor.
# Unlike the first configuration cell, this dict has no 'HeLa' / 'Oxford'
# entries; the dataset is chosen where the loader is built.
args = {'epochs' : 100,
        'batch_size' : 1,
        'amp' : True,
        'bilinear' : False,
        'classes' : 1,
        'learning_rate' : 1e-6,
        # Falsy -> no checkpoint is loaded; otherwise a path passed to torch.load.
        'load': False,
        #'load' : "C:\\Users\\jjkjj\\Equivariant\\EquivariantUNet\\bumbling-sponge-27_checkpoints\\checkpoint_epoch121.pth",
        'class weights' : [1,1,3],
        'epochbreaks' : False,
        'break_length' : 5,
        'etransforms' : efunctions,
        'equivariance_measure' : 'l1',
        # These two flags are compared right after the dict; a mismatch only prints a warning.
        'equivariant' : True,
        'eqerror' : True,
        'augmented' : 'rangle',
        'Linf' : False,
        'eqweight' : 100,
        'n' : 1,
        'debugging' : False,
        'in_channels' : 1,
        'wandb_project' : 'HeLa EUNet',
        'test_on_epoch_end' : True,
        'test augmented' : 'True no identity',
        'test augment' : 'fixed rotations',
        'save_checkpoint' : True,
        'eqweight_scheduler' : False,
        'eqweight_decay' : 1.1,
        'lr_scheduler' : 'cyclic',
        'min_lr' : 1e-9,
        'max_lr' : 1.5e-4,
        # The last three entries were missing their separating commas, which
        # made the whole cell a SyntaxError.
        'product_loss' : False,
        'C1norm' : True,
        'C1weight' : 1,
       }
# Sanity warning only: execution continues when the two flags disagree.
if args['equivariant'] != args['eqerror']:
    print('Equivariant and eqerror are different are you sure?')

model = UNet(args['in_channels'], args['classes'], **args).to(device=device)
model = model.to(memory_format=torch.channels_last)
# Rotation angle for the evaluation transforms. Defined here (same value as in
# the first configuration cell) so this cell does not depend on another cell
# having been run.
angle = 5

#model = UNetLightning(args['in_channels'], args['classes'], **args).to(device=device)


# Optional warm start, done BEFORE evaluating so the test below sees the
# loaded weights rather than the random initialisation.
# NOTE(review): torch.load unpickles the file, so only point this at trusted checkpoints.
if args['load']:
    state_dict = torch.load(args['load'], map_location=device)
    #del state_dict['mask_values']
    model.load_state_dict(state_dict)
    logging.info(f'Model loaded from {args["load"]}')

# Evaluate on the test split. The previous call was pasted from inside the
# trainer and referenced names that do not exist in the notebook (`kwargs`,
# `amp`, `experiment`, `epoch`), plus args['HeLa'] / args['Oxford'], which this
# configuration does not define. It now mirrors the stand-alone evaluation cell
# earlier in the notebook: values come from `args`, epoch is 1 and no run id is
# passed because no wandb run has been started.
# NOTE(review): HeLa=True / Oxford=False are the fallbacks because this
# configuration targets the HeLa data — confirm.
test_loader = config_data(HeLa=args.get('HeLa', True), Oxford=args.get('Oxford', False), split='test',
                          augmented=args['test augmented'],
                          aug_transforms=[[rotate(angle), rotate(angle), 0, 1]])
etest(model, test_loader, device, args['amp'], epoch=1,
      experiment_started=False,
      etransforms=[[rotate(angle), rotate(angle), 0, 1]],
      test_augment=args['test augment'], angle=angle)

print(model.n)
print(model.Linf)
# Print only the shape of the first conv kernel (the recorded output of this
# cell is a torch.Size) instead of dumping the full weight tensor.
print(model.state_dict()['inc.double_conv.0.weight'].shape)


1
False
torch.Size([64, 1, 3, 3])


In [None]:
# Train ten freshly constructed models back to back on the HeLa data;
# `j` is only a repeat counter and is not used in the body.
# NOTE(review): `args2` is not defined in any visible cell. The run summaries
# recorded under this cell (Equivariance Weight 100.0, 100 epochs) match the
# second `args` dict, so it is presumably a copy of that — confirm before re-running.
# NOTE(review): passing HeLa=/Oxford= next to **args2 raises TypeError if
# args2 itself contains those keys.
for j in range(10):
    model = UNet(args['in_channels'], args['classes'], **args2).to(device=device)
    model = model.to(memory_format=torch.channels_last)
    train_loader = config_data(HeLa=True, Oxford=False, **args2)
    train_model(model, device, train_loader, **args2)

0,1
Angle -10 Test Loss,██▇▄▂▃▃▂▂▂▁▂▁▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▁
Angle -20 Test Loss,██▇▄▃▃▃▂▂▃▂▂▂▁▁▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▂▂▁▁▁▂▃▁▂▂
Angle -5 Test Loss,███▄▂▃▃▂▂▂▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▁
Angle 0 Test Loss,███▄▂▃▃▂▂▂▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁▁▁
Angle 10 Test Loss,██▇▄▂▃▃▂▂▂▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁
Angle 20 Test Loss,██▇▄▃▃▃▂▂▃▂▃▂▁▁▂▂▂▁▂▂▂▂▂▂▁▁▁▁▂▂▂▁▂▁▂▂▁▂▁
Angle 5 Test Loss,███▄▂▃▃▂▁▂▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▁
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▁▁▁▁▂▁▁▁▂▂▁▂▃▁▂▃▁▄▃▇▅▆▆▂▅▂▂█▆▂▄▄▁▂▄▁█▇▃▄
Test Equivariance Error,▂▁▁▃▄▄▅▅▅▅▅▆▆▅▆▆▄▅▆▆▆▆▆▆▇▆▆▇▇▅▇▇▅▆▇▄▆▇██

0,1
Angle -10 Test Loss,0.46479
Angle -20 Test Loss,0.53581
Angle -5 Test Loss,0.4252
Angle 0 Test Loss,0.37595
Angle 10 Test Loss,0.45952
Angle 20 Test Loss,0.51634
Angle 5 Test Loss,0.42065
Equivariance Weight,100.0
Random Equivariance Error,0.14346
Test Equivariance Error,0.10431


Epoch 1/100: 100%|██████████| 30/30 [00:18<00:00,  1.62img/s, Cumulative Dice=[0.8787059187889099], loss (batch)=0.672]
Epoch 2/100: 100%|██████████| 30/30 [00:19<00:00,  1.51img/s, Cumulative Dice=[0.9061132669448853], loss (batch)=0.539]
Epoch 3/100: 100%|██████████| 30/30 [00:20<00:00,  1.49img/s, Cumulative Dice=[0.9268796443939209], loss (batch)=0.411]
Epoch 4/100: 100%|██████████| 30/30 [00:20<00:00,  1.50img/s, Cumulative Dice=[0.9451736211776733], loss (batch)=0.362]
Epoch 5/100: 100%|██████████| 30/30 [00:20<00:00,  1.50img/s, Cumulative Dice=[0.9448870420455933], loss (batch)=0.312]
Epoch 6/100: 100%|██████████| 30/30 [00:20<00:00,  1.50img/s, Cumulative Dice=[0.9453360438346863], loss (batch)=0.345]
Epoch 7/100: 100%|██████████| 30/30 [00:20<00:00,  1.50img/s, Cumulative Dice=[0.9481698870658875], loss (batch)=0.285]
Epoch 8/100: 100%|██████████| 30/30 [00:19<00:00,  1.51img/s, Cumulative Dice=[0.9465432167053223], loss (batch)=0.293]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▃▂▂▁▂▁▂▂▁▃▂▃▃▃▃▃▃▃▃▃▃▃▄▅▅▅▅▅▄▅▆▅▅▅▆▅▇▇▇
Angle -20 Test Loss,█▄▂▂▁▂▁▂▂▁▃▂▃▃▃▃▃▃▃▃▄▃▃▄▅▅▅▅▅▄▅▆▅▅▅▆▅▇▇▇
Angle -5 Test Loss,█▃▂▁▁▂▁▂▂▁▃▂▃▃▃▃▃▃▃▃▃▃▃▄▅▅▅▅▅▄▅▆▅▅▅▆▅▇▇▇
Angle 0 Test Loss,█▄▂▁▁▁▁▁▂▁▃▂▃▃▃▃▃▃▃▃▃▃▃▄▄▄▅▅▅▄▅▆▅▅▅▆▅▇▇▇
Angle 10 Test Loss,█▃▂▁▁▂▁▂▂▁▃▂▃▃▃▃▃▃▃▃▃▃▃▄▅▄▅▅▅▄▄▆▅▅▅▅▅▇▇▆
Angle 20 Test Loss,█▄▂▂▁▂▁▂▂▁▃▂▃▃▃▃▃▃▃▃▃▃▃▄▅▅▅▅▅▄▅▆▅▅▅▆▅▇▇▇
Angle 5 Test Loss,█▃▂▁▁▁▁▂▂▁▃▂▃▃▃▃▃▃▃▃▃▃▃▄▅▄▅▅▅▄▅▆▅▅▅▆▅▇▇▆
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▁▁▁▁▁▅▃▂▇▅▅▆▅▄▄▅▃▃▅▆▄▅▅▄▆▅▄▆▆▂▇▆▄█▇▂▄▇▆▄
Test Equivariance Error,▁▁▁▁▂▂▁▂▂▂▂▃▂▆▃▂▅▃▃▂▃▃▃▃▅▄▄▅▅█▆▅█▅▆▇▅▆▅▆

0,1
Angle -10 Test Loss,0.67371
Angle -20 Test Loss,0.72056
Angle -5 Test Loss,0.65447
Angle 0 Test Loss,0.63144
Angle 10 Test Loss,0.66324
Angle 20 Test Loss,0.7111
Angle 5 Test Loss,0.64511
Equivariance Weight,100.0
Random Equivariance Error,0.48541
Test Equivariance Error,0.21854


Epoch 1/100: 100%|███████████| 30/30 [00:19<00:00,  1.55img/s, Cumulative Dice=[0.8401538729667664], loss (batch)=0.82]
Epoch 2/100: 100%|██████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.8851718902587891], loss (batch)=0.697]
Epoch 3/100: 100%|██████████| 30/30 [00:20<00:00,  1.43img/s, Cumulative Dice=[0.9238747954368591], loss (batch)=0.494]
Epoch 4/100: 100%|██████████| 30/30 [00:21<00:00,  1.40img/s, Cumulative Dice=[0.9379076361656189], loss (batch)=0.412]
Epoch 5/100: 100%|██████████| 30/30 [00:21<00:00,  1.41img/s, Cumulative Dice=[0.9477435946464539], loss (batch)=0.389]
Epoch 6/100: 100%|██████████| 30/30 [00:22<00:00,  1.36img/s, Cumulative Dice=[0.9416313171386719], loss (batch)=0.357]
Epoch 7/100: 100%|██████████| 30/30 [00:21<00:00,  1.42img/s, Cumulative Dice=[0.9399658441543579], loss (batch)=0.304]
Epoch 8/100: 100%|████████████| 30/30 [00:21<00:00,  1.38img/s, Cumulative Dice=[0.9507479667663574], loss (batch)=0.3]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▅▂▁▁▁▁▂▂▂▂▃▂▃▂▃▂▂▃▃▃▂▂▃▃▃▄▄▄▅▄▄▄▄▄▅▅▅▅▅
Angle -20 Test Loss,█▅▂▁▂▁▁▂▁▂▂▃▂▂▂▂▁▂▂▂▃▂▂▃▃▃▄▄▄▄▄▄▄▃▄▅▅▅▅▅
Angle -5 Test Loss,█▄▂▁▁▁▁▂▂▂▂▃▂▃▂▃▂▂▃▃▃▂▂▃▄▄▄▄▄▅▄▄▄▄▄▅▅▅▅▅
Angle 0 Test Loss,█▄▂▁▁▁▁▂▁▂▂▃▂▂▂▃▂▂▃▂▃▂▂▃▃▃▄▄▄▄▄▄▄▄▄▅▄▅▅▅
Angle 10 Test Loss,█▄▂▁▁▁▁▂▂▂▂▃▂▂▂▃▂▂▃▃▃▂▂▃▃▃▄▄▄▄▄▄▃▄▄▅▅▅▅▅
Angle 20 Test Loss,█▅▂▁▂▁▁▂▁▁▂▃▂▂▂▂▁▂▂▃▃▂▂▃▃▃▄▄▄▅▄▄▃▄▄▅▅▅▅▅
Angle 5 Test Loss,█▄▂▁▁▁▁▂▂▂▂▃▂▂▂▃▂▂▃▃▃▂▂▃▃▃▄▄▄▅▄▄▄▄▄▅▅▅▅▅
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▂▁▁▁▂▂▁▂▅▅▁▆▅▁▂▅▃▅▄▅▇▅▅▃█▅▄▇▅▂▇▆▂▆▅▁▄▅▄▄
Test Equivariance Error,▁▁▂▁▁▂▂▂▂▂▂▂▂▄▃▂▄▂▃▃▃▃▃▃▄▃▄▅▅▆▅▄▇▅▅█▅▆▅▅

0,1
Angle -10 Test Loss,0.68142
Angle -20 Test Loss,0.71928
Angle -5 Test Loss,0.66518
Angle 0 Test Loss,0.63967
Angle 10 Test Loss,0.67068
Angle 20 Test Loss,0.70856
Angle 5 Test Loss,0.66017
Equivariance Weight,100.0
Random Equivariance Error,0.39304
Test Equivariance Error,0.21961


Epoch 1/100: 100%|██████████| 30/30 [00:19<00:00,  1.52img/s, Cumulative Dice=[0.8605251312255859], loss (batch)=0.671]
Epoch 2/100: 100%|██████████| 30/30 [00:21<00:00,  1.41img/s, Cumulative Dice=[0.8666316270828247], loss (batch)=0.629]
Epoch 3/100: 100%|██████████| 30/30 [00:21<00:00,  1.40img/s, Cumulative Dice=[0.9210823774337769], loss (batch)=0.438]
Epoch 4/100: 100%|██████████| 30/30 [00:22<00:00,  1.35img/s, Cumulative Dice=[0.9391547441482544], loss (batch)=0.382]
Epoch 5/100: 100%|██████████| 30/30 [00:21<00:00,  1.41img/s, Cumulative Dice=[0.9461740851402283], loss (batch)=0.376]
Epoch 6/100: 100%|██████████| 30/30 [00:20<00:00,  1.43img/s, Cumulative Dice=[0.9392403364181519], loss (batch)=0.431]
Epoch 7/100: 100%|██████████| 30/30 [00:21<00:00,  1.42img/s, Cumulative Dice=[0.9470083117485046], loss (batch)=0.305]
Epoch 8/100: 100%|██████████| 30/30 [00:21<00:00,  1.40img/s, Cumulative Dice=[0.9424505233764648], loss (batch)=0.392]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▄▂▁▁▁▂▃▁▂▂▂▂▃▃▃▃▂▂▃▂▃▂▄▄▄▄▅▄▄▅▅▄▄▅▄▄▄▅▆
Angle -20 Test Loss,█▄▂▁▁▂▂▃▁▂▂▂▃▃▃▃▃▃▂▃▃▃▂▄▅▄▅▅▅▄▅▅▅▄▅▄▅▅▆▆
Angle -5 Test Loss,█▄▂▁▁▁▂▃▁▂▂▂▂▃▃▃▃▂▂▃▂▃▂▄▄▄▄▅▄▄▅▅▄▄▅▄▄▅▅▆
Angle 0 Test Loss,█▄▂▂▁▁▂▂▁▂▂▂▂▂▃▂▃▂▂▃▂▃▂▃▄▄▄▄▄▄▄▄▄▃▄▄▄▄▅▅
Angle 10 Test Loss,█▄▂▁▁▁▂▃▁▂▂▂▂▃▂▂▃▂▂▃▂▃▂▄▄▄▄▅▄▄▅▄▄▃▄▄▄▄▅▆
Angle 20 Test Loss,█▄▂▁▂▁▂▃▁▂▂▂▂▃▂▂▃▂▂▃▂▃▂▄▄▄▄▅▄▄▅▅▄▃▅▄▄▄▅▆
Angle 5 Test Loss,█▄▂▁▁▁▂▃▁▂▂▂▂▃▃▃▃▂▂▃▂▃▂▄▄▄▄▅▄▄▅▄▄▃▄▄▄▄▅▆
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▂▁▁▁▂▃▄▁▄▄▃▅▅▂▅▄▂▃▃▄▆█▄▅█▄▃▆▄▂▇▅▁▄▅▂▅▆▅▅
Test Equivariance Error,▁▂▂▁▁▂▁▂▂▂▂▂▂▆▃▂▅▂▃▂▃▃▃▃▄▄▄▅▄█▅▄█▄▅▆▅▅▅▅

0,1
Angle -10 Test Loss,0.67497
Angle -20 Test Loss,0.71884
Angle -5 Test Loss,0.65176
Angle 0 Test Loss,0.62288
Angle 10 Test Loss,0.66445
Angle 20 Test Loss,0.70418
Angle 5 Test Loss,0.6424
Equivariance Weight,100.0
Random Equivariance Error,0.39328
Test Equivariance Error,0.22009


Epoch 1/100: 100%|███████████| 30/30 [00:18<00:00,  1.58img/s, Cumulative Dice=[0.895106315612793], loss (batch)=0.573]
Epoch 2/100: 100%|██████████| 30/30 [00:21<00:00,  1.40img/s, Cumulative Dice=[0.9224526882171631], loss (batch)=0.479]
Epoch 3/100: 100%|██████████| 30/30 [00:21<00:00,  1.42img/s, Cumulative Dice=[0.9295753240585327], loss (batch)=0.363]
Epoch 4/100: 100%|██████████| 30/30 [00:21<00:00,  1.40img/s, Cumulative Dice=[0.9429773688316345], loss (batch)=0.311]
Epoch 5/100: 100%|██████████| 30/30 [00:21<00:00,  1.41img/s, Cumulative Dice=[0.9405688047409058], loss (batch)=0.364]
Epoch 6/100: 100%|██████████| 30/30 [00:21<00:00,  1.41img/s, Cumulative Dice=[0.9422876238822937], loss (batch)=0.288]
Epoch 7/100: 100%|██████████| 30/30 [00:20<00:00,  1.43img/s, Cumulative Dice=[0.9424943923950195], loss (batch)=0.251]
Epoch 8/100: 100%|██████████| 30/30 [00:21<00:00,  1.42img/s, Cumulative Dice=[0.9477333426475525], loss (batch)=0.259]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▃▂▁▁▁▂▁▂▃▃▃▃▃▄▃▃▃▂▃▃▄▄▄▃▆▆▆▅▆▅▆▆▅▇▆▆▆▇▇
Angle -20 Test Loss,█▃▂▂▂▁▂▁▂▄▃▃▃▃▄▂▃▃▂▃▃▃▄▄▃▅▆▆▅▅▅▆▆▅▇▆▆▆▇▇
Angle -5 Test Loss,█▃▂▁▁▁▂▁▂▃▃▃▃▃▄▃▃▃▂▃▃▄▄▄▃▆▆▆▆▆▆▆▆▆▇▆▆▆▇█
Angle 0 Test Loss,█▄▂▁▁▁▂▁▂▃▂▃▃▄▄▃▃▃▃▃▃▄▃▄▄▅▆▆▆▆▅▆▆▆▆▆▆▇▇▇
Angle 10 Test Loss,█▃▂▁▁▁▂▁▂▄▂▃▃▃▄▃▃▃▂▃▃▄▃▄▃▅▆▆▅▆▅▆▆▅▇▆▅▆▇▇
Angle 20 Test Loss,█▃▂▂▁▁▂▁▁▄▂▃▃▃▃▃▃▃▂▃▃▃▄▄▃▅▆▆▅▅▅▆▅▅▇▅▅▅▇▇
Angle 5 Test Loss,█▃▂▁▁▁▂▁▂▃▃▃▃▃▄▃▃▃▂▃▃▄▃▄▃▅▆▆▆▆▅▆▆▆▇▆▆▆▇▇
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▁▁▁▂▃▃▁▁▃▄▅▇▄▂▅▄▁▆▆▃▅█▆▄▆▅▄▅▄▁▆▃▁▅▆▁▅▃▂▃
Test Equivariance Error,▂▁▂▂▂▂▁▂▂▂▂▃▂▅▃▃▄▃▃▂▃▃▃▃▅▃▄▅▅▇▅▄▇▄▅█▅▆▅▅

0,1
Angle -10 Test Loss,0.67314
Angle -20 Test Loss,0.71783
Angle -5 Test Loss,0.65762
Angle 0 Test Loss,0.6314
Angle 10 Test Loss,0.66797
Angle 20 Test Loss,0.71367
Angle 5 Test Loss,0.65259
Equivariance Weight,100.0
Random Equivariance Error,0.23016
Test Equivariance Error,0.23303


Epoch 1/100: 100%|███████████| 30/30 [00:20<00:00,  1.43img/s, Cumulative Dice=[0.8485543131828308], loss (batch)=0.76]
Epoch 2/100: 100%|██████████| 30/30 [00:22<00:00,  1.36img/s, Cumulative Dice=[0.8826529383659363], loss (batch)=0.568]
Epoch 3/100: 100%|██████████| 30/30 [00:21<00:00,  1.41img/s, Cumulative Dice=[0.9190162420272827], loss (batch)=0.492]
Epoch 4/100: 100%|██████████| 30/30 [00:21<00:00,  1.43img/s, Cumulative Dice=[0.9390006065368652], loss (batch)=0.408]
Epoch 5/100: 100%|███████████| 30/30 [00:21<00:00,  1.42img/s, Cumulative Dice=[0.941992461681366], loss (batch)=0.362]
Epoch 6/100: 100%|██████████| 30/30 [00:21<00:00,  1.42img/s, Cumulative Dice=[0.9427234530448914], loss (batch)=0.337]
Epoch 7/100: 100%|███████████| 30/30 [00:21<00:00,  1.40img/s, Cumulative Dice=[0.9448760747909546], loss (batch)=0.27]
Epoch 8/100: 100%|██████████| 30/30 [00:21<00:00,  1.42img/s, Cumulative Dice=[0.9456708431243896], loss (batch)=0.301]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▄▂▁▂▁▁▁▂▁▂▂▂▂▃▂▂▂▃▃▂▄▃▃▄▅▄▄▅▅▅▄▄▅▅▆▅▅▅▆
Angle -20 Test Loss,█▄▂▁▂▂▂▁▂▂▂▁▂▂▃▂▂▂▃▃▃▄▃▃▄▅▄▅▅▅▆▅▅▅▅▆▅▅▆▆
Angle -5 Test Loss,█▄▂▁▂▁▁▁▂▁▂▂▂▃▃▂▂▂▃▃▂▃▃▃▄▅▄▅▅▅▅▄▄▅▅▅▅▅▅▆
Angle 0 Test Loss,█▄▂▁▁▁▁▁▂▁▂▂▂▂▃▂▂▂▂▂▂▃▃▃▄▄▄▄▅▄▅▄▄▄▅▅▅▄▅▅
Angle 10 Test Loss,█▄▂▁▂▁▁▁▂▁▂▂▂▂▃▂▂▂▂▃▂▃▃▃▄▅▄▄▅▄▅▄▄▅▅▅▅▄▅▆
Angle 20 Test Loss,█▄▂▁▂▁▁▁▂▁▂▂▂▂▃▂▂▂▂▃▂▄▃▃▄▅▄▄▅▄▅▄▅▅▅▆▅▄▅▆
Angle 5 Test Loss,█▄▂▁▂▁▁▁▂▁▂▂▂▂▃▂▂▂▃▃▂▃▃▃▄▅▄▄▅▄▅▄▄▅▅▅▅▄▅▆
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▂▁▁▁▁▃▂▃▆▅▄▇▅▂▅▅▂▃▇▄▅▇▅▅█▅▅▇▅▂█▅▁▆▇▁▅▅▄▄
Test Equivariance Error,▁▁▂▂▃▂▂▂▂▂▂▃▃▆▃▃▅▃▃▃▃▄▃▄▅▄▅▆▅█▇▆▇▆▆█▆▇▅▆

0,1
Angle -10 Test Loss,0.68301
Angle -20 Test Loss,0.72508
Angle -5 Test Loss,0.6647
Angle 0 Test Loss,0.63746
Angle 10 Test Loss,0.67358
Angle 20 Test Loss,0.71466
Angle 5 Test Loss,0.65629
Equivariance Weight,100.0
Random Equivariance Error,0.33762
Test Equivariance Error,0.21899


Epoch 1/100: 100%|██████████| 30/30 [00:19<00:00,  1.58img/s, Cumulative Dice=[0.8900338411331177], loss (batch)=0.738]
Epoch 2/100: 100%|██████████| 30/30 [00:20<00:00,  1.49img/s, Cumulative Dice=[0.9043092727661133], loss (batch)=0.471]
Epoch 3/100: 100%|██████████| 30/30 [00:20<00:00,  1.47img/s, Cumulative Dice=[0.9315717220306396], loss (batch)=0.385]
Epoch 4/100: 100%|██████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.9308049082756042], loss (batch)=0.369]
Epoch 5/100: 100%|███████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.941523015499115], loss (batch)=0.319]
Epoch 6/100: 100%|██████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.9434401392936707], loss (batch)=0.269]
Epoch 7/100: 100%|██████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.9471587538719177], loss (batch)=0.283]
Epoch 8/100: 100%|██████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.9473462104797363], loss (batch)=0.295]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▄▁▁▂▁▃▁▁▂▃▂▂▃▃▂▃▃▃▂▄▃▃▄▄▅▅▅▆▅▅▅▅▄▅▅▆▆▆▇
Angle -20 Test Loss,█▄▁▁▃▂▃▁▁▂▃▂▂▃▃▂▄▃▃▂▄▃▃▄▄▅▅▅▆▅▅▅▅▄▅▅▆▆▆▇
Angle -5 Test Loss,█▄▁▁▂▁▂▂▂▂▃▂▂▃▃▃▃▃▃▂▄▃▃▄▄▅▅▆▆▅▅▅▅▅▅▅▆▇▇▇
Angle 0 Test Loss,█▄▂▁▂▁▂▂▂▂▃▂▂▃▃▃▃▃▃▂▃▃▃▄▄▅▅▅▆▅▅▅▅▅▅▅▆▆▆▇
Angle 10 Test Loss,█▄▁▁▂▁▃▂▂▂▃▂▂▃▃▂▃▃▃▂▄▃▃▄▄▅▅▅▆▅▅▅▄▄▅▅▆▆▆▇
Angle 20 Test Loss,█▄▁▁▃▂▃▂▁▂▃▂▂▃▂▂▃▃▃▂▄▂▃▄▄▅▅▅▆▅▅▅▄▄▅▅▆▆▆▇
Angle 5 Test Loss,█▄▁▁▂▁▂▂▂▂▃▂▂▃▃▃▃▃▃▂▄▃▃▄▄▅▅▅▆▅▅▅▅▄▅▅▆▆▆▇
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▁▁▁▁▂▅▂▃█▇▂▆▇▁▄▃▁▃▃▄▃▅▃▄▅▄▃▇▃▂▅▄▁▃▅▁▃▃▂▃
Test Equivariance Error,▁▁▁▁▁▂▁▁▂▁▂▂▂▅▂▂▆▂▂▂▂▃▃▃▄▄▄▅▄▇▅▄▆▄▄█▄▆▅▅

0,1
Angle -10 Test Loss,0.67178
Angle -20 Test Loss,0.70311
Angle -5 Test Loss,0.66008
Angle 0 Test Loss,0.63578
Angle 10 Test Loss,0.66318
Angle 20 Test Loss,0.69899
Angle 5 Test Loss,0.65053
Equivariance Weight,100.0
Random Equivariance Error,0.13094
Test Equivariance Error,0.23613


Epoch 1/100: 100%|██████████| 30/30 [00:18<00:00,  1.59img/s, Cumulative Dice=[0.8153320550918579], loss (batch)=0.821]
Epoch 2/100: 100%|██████████| 30/30 [00:19<00:00,  1.50img/s, Cumulative Dice=[0.9041799902915955], loss (batch)=0.545]
Epoch 3/100: 100%|██████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.9278862476348877], loss (batch)=0.494]
Epoch 4/100: 100%|██████████| 30/30 [00:20<00:00,  1.47img/s, Cumulative Dice=[0.9314574599266052], loss (batch)=0.407]
Epoch 5/100: 100%|██████████| 30/30 [00:20<00:00,  1.47img/s, Cumulative Dice=[0.9394694566726685], loss (batch)=0.349]
Epoch 6/100: 100%|██████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.9462690353393555], loss (batch)=0.324]
Epoch 7/100: 100%|██████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.9441965818405151], loss (batch)=0.303]
Epoch 8/100: 100%|██████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.9468040466308594], loss (batch)=0.266]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▄▂▂▁▁▁▁▁▁▁▂▂▂▃▂▃▂▃▄▃▃▂▃▃▄▄▄▄▄▄▃▅▄▃▃▄▅▅▆
Angle -20 Test Loss,█▅▂▃▁▁▁▁▁▂▂▃▂▂▃▃▄▃▄▄▃▄▂▃▃▄▅▅▄▅▄▄▅▄▄▄▅▅▅▆
Angle -5 Test Loss,█▄▂▂▁▁▁▁▁▁▁▂▂▂▃▂▃▂▃▄▂▃▂▃▃▄▄▄▄▄▄▄▅▄▄▃▅▅▅▆
Angle 0 Test Loss,█▄▂▂▁▁▁▂▁▁▂▂▂▂▃▂▃▂▃▃▂▃▂▃▃▄▄▄▄▄▄▄▅▄▄▄▅▅▅▅
Angle 10 Test Loss,█▄▁▂▁▁▁▁▁▁▁▂▂▂▃▂▃▂▃▃▃▃▂▃▃▄▄▄▄▄▄▃▅▄▃▃▄▄▅▅
Angle 20 Test Loss,█▄▂▂▁▁▁▁▁▁▁▂▂▂▃▂▄▂▄▄▃▃▂▃▃▄▄▄▄▄▄▃▅▄▃▃▄▄▅▆
Angle 5 Test Loss,█▄▂▂▁▁▁▁▁▁▁▂▂▂▃▂▃▂▃▃▂▃▂▃▃▄▄▄▄▄▄▃▅▄▄▃▅▅▅▅
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▂▁▁▂▁▃▁▃▂▃▅█▅▂▅▄▁▄█▆▃▇▃▃▇▃▂▅▂▂▇▅▂▃█▂▄▅▂▄
Test Equivariance Error,▃▁▁▁▁▁▁▁▂▁▁▂▂▄▂▂▅▂▂▂▂▃▂▃▄▃▄▄▄▇▅▃█▄▄█▅▅▄▅

0,1
Angle -10 Test Loss,0.69201
Angle -20 Test Loss,0.73241
Angle -5 Test Loss,0.67466
Angle 0 Test Loss,0.65589
Angle 10 Test Loss,0.68312
Angle 20 Test Loss,0.72009
Angle 5 Test Loss,0.67045
Equivariance Weight,100.0
Random Equivariance Error,0.23911
Test Equivariance Error,0.23368


Epoch 1/100: 100%|██████████| 30/30 [00:19<00:00,  1.58img/s, Cumulative Dice=[0.8758037090301514], loss (batch)=0.742]
Epoch 2/100: 100%|███████████| 30/30 [00:19<00:00,  1.50img/s, Cumulative Dice=[0.885552167892456], loss (batch)=0.576]
Epoch 3/100: 100%|██████████| 30/30 [00:20<00:00,  1.47img/s, Cumulative Dice=[0.9328572154045105], loss (batch)=0.433]
Epoch 4/100: 100%|██████████| 30/30 [00:20<00:00,  1.49img/s, Cumulative Dice=[0.9407636523246765], loss (batch)=0.346]
Epoch 5/100: 100%|██████████| 30/30 [00:20<00:00,  1.45img/s, Cumulative Dice=[0.9426813125610352], loss (batch)=0.337]
Epoch 6/100: 100%|███████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.937985360622406], loss (batch)=0.337]
Epoch 7/100: 100%|██████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.9434985518455505], loss (batch)=0.286]
Epoch 8/100: 100%|████████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.94794100522995], loss (batch)=0.264]
Epoch 9/100: 100%|██████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▄▃▃▂▁▂▃▁▁▁▃▃▃▃▃▄▃▃▃▃▄▄▄▄▅▅▅▆▅▄▄▄▅▅▆▇▆▇▆
Angle -20 Test Loss,█▄▃▃▂▁▂▄▁▁▂▃▃▃▄▄▄▄▃▄▃▄▄▄▄▅▅▅▆▅▅▄▅▅▅▆▇▆▇▆
Angle -5 Test Loss,█▄▃▂▂▁▂▃▁▁▂▃▃▃▃▃▄▃▃▃▃▃▄▄▄▅▅▅▆▅▅▄▄▅▅▆▆▆▆▆
Angle 0 Test Loss,█▄▂▂▁▁▂▂▁▁▂▂▂▃▃▃▃▃▃▃▃▃▄▃▄▄▅▅▆▅▄▄▄▅▅▆▆▅▆▆
Angle 10 Test Loss,█▄▃▃▂▁▂▃▁▁▁▃▃▃▃▃▄▃▂▃▃▃▄▃▄▅▅▅▆▅▄▄▄▅▅▆▆▆▆▆
Angle 20 Test Loss,█▄▃▃▂▁▂▃▁▂▁▃▃▃▄▃▄▄▃▃▃▄▄▄▄▅▅▅▆▅▅▄▄▅▅▆▇▆▇▆
Angle 5 Test Loss,█▄▂▂▂▁▂▃▁▁▁▂▃▃▃▃▄▃▂▃▃▃▄▃▄▅▅▅▆▅▄▄▄▅▅▆▆▆▆▆
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▂▁▁▂▃▅▄▆▄▂▃▇▆▂▆▆▂▆▆▅▆█▆▅▇▅▅▆▅▂▇▇▁▆▆▂▇▇▇▆
Test Equivariance Error,▁▁▁▁▁▂▂▂▂▂▂▃▂▅▃▃▄▂▃▂▃▃▃▃▅▃▄▆▅█▆▅▇▅▅█▆▆▆▆

0,1
Angle -10 Test Loss,0.66965
Angle -20 Test Loss,0.70768
Angle -5 Test Loss,0.65037
Angle 0 Test Loss,0.62634
Angle 10 Test Loss,0.65962
Angle 20 Test Loss,0.70682
Angle 5 Test Loss,0.64362
Equivariance Weight,100.0
Random Equivariance Error,0.51699
Test Equivariance Error,0.23546


Epoch 1/100: 100%|██████████| 30/30 [00:18<00:00,  1.58img/s, Cumulative Dice=[0.8785585761070251], loss (batch)=0.661]
Epoch 2/100: 100%|███████████| 30/30 [00:19<00:00,  1.50img/s, Cumulative Dice=[0.9014581441879272], loss (batch)=0.68]
Epoch 3/100: 100%|███████████| 30/30 [00:20<00:00,  1.46img/s, Cumulative Dice=[0.926384687423706], loss (batch)=0.442]
Epoch 4/100: 100%|██████████| 30/30 [00:20<00:00,  1.49img/s, Cumulative Dice=[0.9413245916366577], loss (batch)=0.369]
Epoch 5/100: 100%|██████████| 30/30 [00:20<00:00,  1.47img/s, Cumulative Dice=[0.9384891986846924], loss (batch)=0.354]
Epoch 6/100: 100%|██████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.9394712448120117], loss (batch)=0.336]
Epoch 7/100: 100%|██████████| 30/30 [00:20<00:00,  1.45img/s, Cumulative Dice=[0.9485604166984558], loss (batch)=0.308]
Epoch 8/100: 100%|██████████| 30/30 [00:20<00:00,  1.47img/s, Cumulative Dice=[0.9383193850517273], loss (batch)=0.291]
Epoch 9/100: 100%|███████████| 30/30 [00

0,1
Angle -10 Test Loss,█▄▂▁▂▁▂▁▁▂▂▂▂▂▃▃▂▂▃▂▃▂▃▄▃▄▅▅▄▅▄▅▅▅▅▅▅▅▆▆
Angle -20 Test Loss,█▅▂▁▂▂▂▂▁▂▃▂▂▃▃▃▂▃▄▂▃▃▃▄▃▅▅▅▄▅▄▅▅▅▆▆▅▅▆▆
Angle -5 Test Loss,█▄▂▁▁▁▂▁▁▂▂▂▂▂▃▃▂▂▃▂▃▂▃▄▃▄▅▅▄▅▄▅▅▅▅▅▅▅▆▆
Angle 0 Test Loss,█▄▂▂▁▁▂▁▁▂▂▂▂▂▂▂▂▂▃▂▂▂▃▃▃▄▄▅▄▅▄▅▄▄▅▅▅▅▅▆
Angle 10 Test Loss,█▄▂▁▂▁▂▁▁▂▂▂▂▂▃▃▂▂▃▂▂▂▃▃▃▄▄▅▄▅▄▅▅▄▅▅▅▅▅▆
Angle 20 Test Loss,█▄▂▁▂▁▂▂▁▂▃▂▂▃▃▃▂▃▃▂▃▂▃▄▃▄▅▅▄▅▄▅▅▅▅▅▅▅▆▆
Angle 5 Test Loss,█▄▂▁▁▁▂▁▁▂▂▂▂▂▃▂▂▂▃▂▂▂▃▄▃▄▄▅▄▅▄▅▅▄▅▅▅▅▆▆
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▂▁▁▁▁▂▁▁▃▁▂▆▂▁▃▃▁▃▅▅▅▃▄▅█▂▄▇▆▁█▇▂▅▇▂▄▆▅▆
Test Equivariance Error,▁▂▁▁▁▂▁▁▂▂▂▃▂▅▂▂▅▂▂▃▂▃▃▃▄▃▄▄▄▇▅▄▆▄▅█▅▅▄▅

0,1
Angle -10 Test Loss,0.67217
Angle -20 Test Loss,0.71573
Angle -5 Test Loss,0.65988
Angle 0 Test Loss,0.6401
Angle 10 Test Loss,0.66734
Angle 20 Test Loss,0.70889
Angle 5 Test Loss,0.65475
Equivariance Weight,100.0
Random Equivariance Error,0.57162
Test Equivariance Error,0.21504


Epoch 1/100: 100%|██████████| 30/30 [00:18<00:00,  1.59img/s, Cumulative Dice=[0.8738383054733276], loss (batch)=0.684]
Epoch 2/100: 100%|██████████| 30/30 [00:19<00:00,  1.52img/s, Cumulative Dice=[0.9070534110069275], loss (batch)=0.586]
Epoch 3/100: 100%|██████████| 30/30 [00:20<00:00,  1.48img/s, Cumulative Dice=[0.9309066534042358], loss (batch)=0.419]
Epoch 4/100: 100%|██████████| 30/30 [00:20<00:00,  1.49img/s, Cumulative Dice=[0.9371166229248047], loss (batch)=0.397]
Epoch 5/100: 100%|██████████| 30/30 [00:20<00:00,  1.47img/s, Cumulative Dice=[0.9393486380577087], loss (batch)=0.343]
Epoch 6/100: 100%|████████████| 30/30 [00:20<00:00,  1.49img/s, Cumulative Dice=[0.9431242942810059], loss (batch)=0.3]
Testing round, angle = -20:  20%|██████████▍                                         | 6/30 [00:01<00:06,  3.84batch/s]

In [None]:
# Ten back-to-back training runs with identical settings. A new UNet instance
# is constructed on every pass, and the data loader is rebuilt each time.
for j in range(10):
    # Overrides written into the shared `args` dict before each run.
    args.update({'eqweight': 10, 'augmented': 'rangle'})
    model = (
        UNet(args['in_channels'], args['classes'], **args)
        .to(device=device)
        .to(memory_format=torch.channels_last)
    )
    train_loader = config_data(HeLa=True, Oxford=False, **args)
    train_model(model, device, train_loader, **args)

[34m[1mwandb[0m: Currently logged in as: [33mjjkjjk23[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/50: 100%|████████████| 30/30 [00:49<00:00,  1.65s/img, Cumulative Dice=[0.7585556507110596], loss (batch)=1.95]
Epoch 2/50: 100%|████████████| 30/30 [00:51<00:00,  1.71s/img, Cumulative Dice=[0.8912205696105957], loss (batch)=1.17]
Epoch 3/50: 100%|███████████| 30/30 [00:51<00:00,  1.72s/img, Cumulative Dice=[0.9097835421562195], loss (batch)=0.781]
Epoch 4/50: 100%|███████████| 30/30 [00:51<00:00,  1.72s/img, Cumulative Dice=[0.9316715002059937], loss (batch)=0.614]
Epoch 5/50: 100%|███████████| 30/30 [00:51<00:00,  1.72s/img, Cumulative Dice=[0.9398154020309448], loss (batch)=0.495]
Epoch 6/50: 100%|████████████| 30/30 [00:52<00:00,  1.74s/img, Cumulative Dice=[0.938301682472229], loss (batch)=0.474]
Epoch 7/50: 100%|███████████| 30/30 [00:52<00:00,  1.74s/img, Cumulative Dice=[0.9422001838684082], loss (batch)=0.392]
Epoch 8/50: 100%|███████████| 30/30 [00:51<00:00,  1.73s/img, Cumulative Dice=[0.9475051164627075], loss (batch)=0.327]
Epoch 9/50: 100%|████████████| 30/30 [00

0,1
Angle -10 Test Loss,█▆▄▃▂▁▁▂▂▂▁▁▂▂▁▁▂▂▁▂▁▂▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂
Angle -20 Test Loss,█▆▅▃▁▁▁▂▂▂▂▁▂▂▁▂▂▂▁▂▁▁▂▂▂▂▂▂▃▂▂▂▃▂▃▂▂▂▂▂
Angle -5 Test Loss,█▆▄▃▂▁▁▂▂▂▂▁▂▂▁▁▂▂▁▂▁▂▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂
Angle 0 Test Loss,█▆▄▃▂▂▁▂▂▂▁▁▁▂▁▁▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Angle 10 Test Loss,█▆▄▃▂▁▁▂▂▂▂▁▂▂▁▁▂▂▁▂▁▂▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂
Angle 20 Test Loss,█▆▅▃▁▁▁▂▂▂▂▁▂▂▁▁▂▂▁▂▁▁▁▂▂▂▂▂▃▂▂▂▂▂▃▂▂▁▂▂
Angle 5 Test Loss,█▆▄▃▂▁▁▂▂▂▂▁▂▂▁▁▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▇▂▂▁▁▁▁▁▁▁▁▁▁▁▁▃▂▂▂▃▁▁▂▁▃▄▁▄▄▄▁▁▄▄▂█▂▁▃▅
Test Equivariance Error,▅▃▃▃▂▄▄▂▃▄▁▄▄▅▄▄▅▅▄▄▅▆▆▅▅▅▄▆▆█▇▆▇▇▇▇▇▇▆█

0,1
Angle -10 Test Loss,0.50282
Angle -20 Test Loss,0.57662
Angle -5 Test Loss,0.46505
Angle 0 Test Loss,0.41795
Angle 10 Test Loss,0.49363
Angle 20 Test Loss,0.55111
Angle 5 Test Loss,0.45899
Equivariance Weight,10.0
Random Equivariance Error,0.34719
Test Equivariance Error,0.10776


Epoch 1/50: 100%|████████████| 30/30 [00:47<00:00,  1.59s/img, Cumulative Dice=[0.8837778568267822], loss (batch)=1.45]
Epoch 2/50: 100%|███████████| 30/30 [00:51<00:00,  1.70s/img, Cumulative Dice=[0.8986923694610596], loss (batch)=0.833]
Epoch 3/50: 100%|███████████| 30/30 [00:50<00:00,  1.70s/img, Cumulative Dice=[0.9230343699455261], loss (batch)=0.664]
Epoch 4/50: 100%|███████████| 30/30 [00:51<00:00,  1.71s/img, Cumulative Dice=[0.9395169615745544], loss (batch)=0.371]
Epoch 5/50: 100%|███████████| 30/30 [00:51<00:00,  1.70s/img, Cumulative Dice=[0.9401797652244568], loss (batch)=0.431]
Epoch 6/50: 100%|███████████| 30/30 [00:50<00:00,  1.69s/img, Cumulative Dice=[0.9423207640647888], loss (batch)=0.401]
Epoch 7/50: 100%|████████████| 30/30 [00:51<00:00,  1.71s/img, Cumulative Dice=[0.9411848783493042], loss (batch)=0.35]
Epoch 8/50: 100%|███████████| 30/30 [00:51<00:00,  1.71s/img, Cumulative Dice=[0.9476143717765808], loss (batch)=0.322]
Epoch 9/50: 100%|███████████| 30/30 [00:

0,1
Angle -10 Test Loss,█▆▄▃▁▁▂▂▂▁▁▂▂▃▂▂▂▂▁▂▂▂▂▂▂▂▂▃▂▃▂▃▂▂▃▂▂▂▃▂
Angle -20 Test Loss,█▆▄▃▁▁▂▃▃▂▁▂▂▃▂▂▃▂▁▂▂▂▃▃▂▂▃▃▂▄▂▃▂▂▃▂▂▂▄▂
Angle -5 Test Loss,█▅▄▃▁▁▂▂▂▁▁▂▂▂▁▂▂▁▁▂▂▂▂▂▂▂▂▃▂▃▂▃▂▂▃▂▂▂▃▂
Angle 0 Test Loss,█▆▄▃▂▁▂▁▂▁▁▂▁▂▁▁▂▁▁▂▂▂▂▂▂▂▂▃▂▃▂▃▂▂▃▂▂▂▃▂
Angle 10 Test Loss,█▆▄▃▁▁▂▂▂▁▁▂▂▂▂▂▂▁▁▂▂▂▂▂▂▂▂▃▂▃▂▃▂▂▃▂▂▂▃▂
Angle 20 Test Loss,█▆▄▃▁▂▂▃▃▂▁▂▂▃▂▂▃▂▁▂▂▂▃▃▂▂▃▃▂▃▂▃▂▂▃▂▂▃▄▂
Angle 5 Test Loss,█▅▄▃▁▁▁▂▂▁▁▂▂▂▁▁▂▁▁▂▂▂▂▂▂▂▂▃▂▃▂▃▂▂▃▂▂▂▃▂
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▃▁▁▁▁▁▁▁▁▁▁▁▅▃▆▆▅▅▅▆▆▆▅▁▄▆▂▅▃▇▅▆██▆▇▇▆▅█
Test Equivariance Error,▂▂▁▁▁▂▂▂▂▃▁▃▃▄▄▄▅▄▄▄▆▅▄▄▅▅▄▅▆█▅▆█▆▆▇▇▆▅▆

0,1
Angle -10 Test Loss,0.5072
Angle -20 Test Loss,0.56755
Angle -5 Test Loss,0.47556
Angle 0 Test Loss,0.43516
Angle 10 Test Loss,0.50854
Angle 20 Test Loss,0.56781
Angle 5 Test Loss,0.47451
Equivariance Weight,10.0
Random Equivariance Error,0.60995
Test Equivariance Error,0.10847


Epoch 1/50: 100%|████████████| 30/30 [00:49<00:00,  1.66s/img, Cumulative Dice=[0.8596620559692383], loss (batch)=1.66]
Epoch 2/50: 100%|████████████| 30/30 [00:52<00:00,  1.76s/img, Cumulative Dice=[0.8977052569389343], loss (batch)=1.04]
Epoch 3/50: 100%|███████████| 30/30 [00:52<00:00,  1.76s/img, Cumulative Dice=[0.9031875133514404], loss (batch)=0.771]
Epoch 4/50: 100%|███████████| 30/30 [00:52<00:00,  1.76s/img, Cumulative Dice=[0.9294345378875732], loss (batch)=0.576]
Epoch 5/50: 100%|███████████| 30/30 [00:52<00:00,  1.75s/img, Cumulative Dice=[0.9410277605056763], loss (batch)=0.518]
Epoch 6/50: 100%|███████████| 30/30 [00:52<00:00,  1.75s/img, Cumulative Dice=[0.9430646300315857], loss (batch)=0.463]
Epoch 7/50: 100%|███████████| 30/30 [00:52<00:00,  1.75s/img, Cumulative Dice=[0.9397327303886414], loss (batch)=0.483]
Epoch 8/50: 100%|███████████| 30/30 [00:52<00:00,  1.74s/img, Cumulative Dice=[0.9436158537864685], loss (batch)=0.393]
Epoch 9/50: 100%|████████████| 30/30 [00

0,1
Angle -10 Test Loss,█▆▅▃▂▂▂▂▂▂▁▁▂▃▂▁▁▂▂▂▁▂▁▂▂▂▂▂▂▃▂▂▂▂▂▂▃▂▃▂
Angle -20 Test Loss,█▆▅▃▃▂▂▂▂▂▁▁▂▃▂▂▁▂▂▂▁▂▂▂▂▃▃▂▂▃▂▂▃▂▂▃▃▂▃▃
Angle -5 Test Loss,█▆▅▃▂▂▂▂▂▂▁▁▂▂▂▁▁▁▂▂▁▂▁▂▂▂▂▂▂▃▂▂▂▂▂▂▃▂▃▂
Angle 0 Test Loss,█▆▅▄▂▂▂▁▂▂▁▁▂▂▂▁▁▁▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂
Angle 10 Test Loss,█▆▅▃▂▂▂▂▂▂▁▁▂▃▂▁▁▁▂▂▁▂▁▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂
Angle 20 Test Loss,█▆▅▃▃▂▂▂▂▂▁▁▂▃▂▂▁▂▂▂▁▂▁▂▂▂▂▂▂▃▂▂▂▂▂▃▂▂▃▂
Angle 5 Test Loss,█▆▅▃▂▂▂▂▂▂▁▁▂▂▂▁▁▁▂▂▁▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Equivariance Weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Random Equivariance Error,▄▂▁▁▁▁▁▁▁▁▁▃▁▂▂▁▂▁▆▂▃▃▁▁▂▃▁▂▃▄▂▂▅▂▂█▅█▅▇
Test Equivariance Error,▃▃▂▂▁▃▃▁▄▃▂▃▄▄▄▄▅▅▅▄▅▄▅▃▆▆▃▆▆█▆▆█▇▇▇▇▇▆█

0,1
Angle -10 Test Loss,0.52025
Angle -20 Test Loss,0.5973
Angle -5 Test Loss,0.48688
Angle 0 Test Loss,0.44185
Angle 10 Test Loss,0.51232
Angle 20 Test Loss,0.56129
Angle 5 Test Loss,0.48043
Equivariance Weight,10.0
Random Equivariance Error,0.55441
Test Equivariance Error,0.11012


Epoch 1/50: 100%|████████████| 30/30 [00:50<00:00,  1.70s/img, Cumulative Dice=[0.8719428777694702], loss (batch)=1.64]
Epoch 2/50: 100%|███████████████| 30/30 [00:53<00:00,  1.78s/img, Cumulative Dice=[0.9015082120895386], loss (batch)=1]
Epoch 3/50: 100%|███████████| 30/30 [00:53<00:00,  1.78s/img, Cumulative Dice=[0.9198138117790222], loss (batch)=0.689]
Epoch 4/50: 100%|███████████| 30/30 [00:53<00:00,  1.77s/img, Cumulative Dice=[0.9313530921936035], loss (batch)=0.539]
Epoch 5/50: 100%|████████████| 30/30 [00:53<00:00,  1.77s/img, Cumulative Dice=[0.9446693658828735], loss (batch)=0.45]
Epoch 6/50: 100%|███████████| 30/30 [00:52<00:00,  1.76s/img, Cumulative Dice=[0.9433648586273193], loss (batch)=0.426]
Epoch 7/50: 100%|███████████| 30/30 [00:52<00:00,  1.76s/img, Cumulative Dice=[0.9463836550712585], loss (batch)=0.374]
Epoch 8/50: 100%|███████████| 30/30 [00:52<00:00,  1.75s/img, Cumulative Dice=[0.9473505020141602], loss (batch)=0.323]
Epoch 9/50: 100%|████████████| 30/30 [00

In [10]:
# NOTE(review): these definitions repeat the setup cell near the top of the
# notebook; re-running this cell rebinds the same global names.
torch.cuda.empty_cache()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#device=torch.device('cpu')
epsilon = 0

# Nearest-exact resizers to 550x550, 450x450 and 512x512.
sizeup = transforms.Resize((550, 550), interpolation=TF.InterpolationMode.NEAREST_EXACT)
sizedown = transforms.Resize((450, 450), interpolation=TF.InterpolationMode.NEAREST_EXACT)
resize = transforms.Resize((512, 512), interpolation=TF.InterpolationMode.NEAREST_EXACT)

# Displacement field: a 3x3 grid whose only non-zero entry is the centre (0.5)
# is bicubically upsampled to 512x512 and stacked with an all-zero second
# component, giving a tensor of shape [1, 512, 512, 2].
deformation = (
    torch.tensor([[0, 0, 0], [0, .5, 0], [0, 0, 0]])
    .to(device=device, dtype=torch.float32)
    .unsqueeze(0)
)
deformation = transforms.Resize((512, 512), interpolation=TF.InterpolationMode.BICUBIC)(deformation)
ydeformation = torch.zeros((1, 512, 512), device=device, dtype=torch.float32)
deformation = torch.stack((deformation, ydeformation), dim=3)


def deform(tensor):
    """Warp `tensor` with the module-level `deformation` field (nearest interpolation, fill 0.0)."""
    return TF.elastic_transform(tensor, deformation, TF.InterpolationMode.NEAREST, 0.0)

def rotate(angle):
    """Return a one-argument callable that rotates its input by `angle`.

    The returned callable forwards to torchvision's functional `rotate` with
    the captured angle.

    NOTE(review): this definition shadows the `rotate` imported from
    `etrainerfunctions` in the notebook's import cell.
    """
    def _rotate_by_angle(inputs):
        return TF.rotate(inputs, angle)

    return _rotate_by_angle


# Resize round-trips: go to 550x550 (upscale) or 450x450 (etransform) first,
# then back to 512x512.
upscale = transforms.Compose([sizeup, resize])
etransform = transforms.Compose([sizedown, resize])
# Combines etrainer.pad24 with the 450x450 resize; the order in which the two
# are applied is determined by etrainer.compose, which is not visible here.
shrinkcrop = etrainer.compose(etrainer.pad24, sizedown)

def shift(x, shiftnum=1, axis=-1):
    """Shift the contents of `x` along `axis`, filling vacated positions with zeros.

    Args:
        x: Input tensor.
        shiftnum: Integer offset. A positive value drops the first `shiftnum`
            entries along `axis` and zero-pads at the end; a negative value
            drops the last `|shiftnum|` entries and zero-pads at the start;
            0 leaves the data unchanged.
        axis: Dimension to shift along (default: the last one).

    Returns:
        A tensor with the same shape as `x`. If `|shiftnum|` is at least the
        size of `axis`, every element has been shifted out and the result is
        all zeros.

    Raises:
        ValueError: if `shiftnum` compares as neither zero, positive nor
            negative (e.g. NaN).
    """
    # Work on the last dimension so slicing and padding stay simple.
    x = torch.transpose(x, axis, -1)
    width = x.shape[-1]
    if shiftnum == 0:
        padded = x
    elif abs(shiftnum) >= width:
        # Without this guard the slice is empty but the pad still adds
        # |shiftnum| elements, so the output would be wider than the input.
        padded = torch.zeros_like(x)
    elif shiftnum > 0:
        # A 2-tuple pads only the last dimension: (left, right).
        padded = nn.functional.pad(x[..., shiftnum:], (0, shiftnum))
    elif shiftnum < 0:
        padded = nn.functional.pad(x[..., :shiftnum], (-shiftnum, 0))
    else:
        raise ValueError
    return torch.transpose(padded, axis, -1)
def randshift(x):
    """Apply `shift` to `x` with a random offset along a random trailing axis.

    The offset is drawn first, uniformly from the integers -6..6; the axis is
    drawn second and is either -2 or -1.
    """
    offset = random.randint(-6, 6)
    dim = random.randint(-2, -1)
    return shift(x, shiftnum=offset, axis=dim)
# `efunctions` is the list later stored under the 'etransforms' config key.
# Each entry is a 3-element list: two callables followed by a number
# (`epsilon`, or 0 in some of the disabled variants below).
# Only the RandomRotation entry is active; the commented-out lines are
# alternative transform sets (scaling, shifts, elastic deformation, 90-degree
# rotation) kept for reference.
#This is for scaling
#efunctions=[[etransform, etransform, epsilon], [upscale,upscale,epsilon]] 
# NOTE(review): RandomRotation samples a new angle on every call, so the two
# callables in this entry are not guaranteed to apply the same rotation —
# confirm that is what the consumer of 'etransforms' expects.
efunctions=[[torchvision.transforms.RandomRotation(10), torchvision.transforms.RandomRotation(10),epsilon]]
#efunctions=efunctions+[[lambda x : shift(x, shiftnum, axis), lambda x : shift(x, shiftnum, axis), epsilon] for shiftnum in range(-1,1,2) for axis in range(-1,1,2)]+[[etransform, etransform, epsilon], [upscale,upscale,epsilon]]
#efunctions += [[randshift, randshift, epsilon]]
#efunctions = efunctions+[[deform,deform,epsilon]]
#efunctions = efunctions + [[rotate(90), rotate(90),0]]
#efunctions = efunctions + [[torchvision.transforms.ElasticTransform(interpolation=TF.InterpolationMode.NEAREST), torchvision.transforms.ElasticTransform(interpolation=TF.InterpolationMode.NEAREST), epsilon]]
# The triple-quoted string below is disabled code held in a bare string
# expression; it is evaluated as a string and discarded, so none of it runs.
# It would append one elastic-deformation entry per combination of the loop
# variables to `efunctions`.
"""
for x in range(-1,1):
    for y in range(-1,1):
        for i in range(3):
            for j in range(3):
                for a in range(3):
                    for b in range(3):
                        deformationx = torch.tensor([[x if i==k else 0 for k in range(3)] if j==l else [0,0,0] for l in range(3)]).to(device=device, dtype=torch.float32)
                        deformationy = torch.tensor([[y if a==k else 0 for k in range(3)] if b==l else [0,0,0] for l in range(3)]).to(device=device, dtype=torch.float32)
                        deformationx=torch.unsqueeze(deformationx, dim=0)
                        deformationy=torch.unsqueeze(deformationy, dim=0)
                        deformationx= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformationx)
                        deformationy= torchvision.transforms.Resize((512,512), interpolation=TF.InterpolationMode.BICUBIC)(deformationy)
                        deformation = torch.stack([deformationx, deformationy], dim=3)
                        deform = lambda tensor : TF.elastic_transform(tensor, deformation, TF.InterpolationMode.NEAREST, 0.0)
                        efunctions.append([deform,deform,0])
"""
# Run configuration dictionary.
# NOTE(review): this dict is bound to `args2`, but the statements that follow
# it in this cell (the equivariant/eqerror check, the UNet construction and the
# checkpoint load) all read `args`, which is not defined in this cell and so
# must come from an earlier one. Confirm which dictionary is intended.
args2 = {'epochs' : 100,
        'batch_size' : 1,
        'amp' : True,
        'bilinear' : False,
        'classes' : 1,
        'learning_rate' : 1e-6,
        # False means no checkpoint is restored; the commented line shows a
        # machine-specific absolute checkpoint path that was used instead.
        'load': False,
        #'load' : "C:\\Users\\jjkjj\\Equivariant\\EquivariantUNet\\bumbling-sponge-27_checkpoints\\checkpoint_epoch121.pth",
        'class weights' : [1,1,3],
        'epochbreaks' : False,
        'break_length' : 5,
        # The transform list built just above in this cell.
        'etransforms' : efunctions,
        'equivariance_measure' : 'l1',
        'equivariant' : False,
        'eqerror' : False,
        'augmented' : 'rangle',
        'Linf' : False,
        'eqweight' : 100,
        'n' : 1,
        'debugging' : False,
        'in_channels' : 1,
        'wandb_project' : 'HeLa EUNet',
        'test_on_epoch_end' : True,
        'test augmented' : 'True no identity',
        'test augment' : 'fixed rotations',
        'save_checkpoint' : True,
        'eqweight_scheduler' : False,
        'eqweight_decay' : 1.1,
        'lr_scheduler' : 'cyclic',
        'min_lr' : 1e-9,
        'max_lr' : 1.5e-4,
        'product_loss' : False
       }
# NOTE(review): everything below reads `args`, not the `args2` dict defined
# just above in this cell, so it depends on `args` already existing from an
# earlier cell. Confirm which configuration is meant.

# Warn when the two equivariance switches disagree.
if args['equivariant'] != args['eqerror']:
    print('Equivariant and eqerror are different are you sure?')

# Build the network, move it to the selected device, then switch it to the
# channels-last memory format.
model = (
    UNet(args['in_channels'], args['classes'], **args)
    .to(device=device)
    .to(memory_format=torch.channels_last)
)

#model = UNetLightning(args['in_channels'], args['classes'], **args).to(device=device)


# args['load'] is either falsy (skip) or the path of a checkpoint to restore.
if args['load']:
    state_dict = torch.load(args['load'], map_location=device)
    #del state_dict['mask_values']
    model.load_state_dict(state_dict)
    logging.info(f'Model loaded from {args["load"]}')

# Show the model's `n` and `Linf` attributes, one per line.
print(model.n, model.Linf, sep='\n')
#print(model.state_dict())




1
False
