# MyoPS 2020 Challenge


In [1]:
!pip install fastai2
!pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI



In [2]:
from fastai2.vision.all import *
from fastai2.vision.models import resnet34
from monai.losses import FocalLoss
import gc

In [3]:
# Mount Google Drive (Colab only) so the data under /content/drive/My Drive is readable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# Root of the PNG-exported data; getMyopsDls loads images from f'{path}/{images}'.
# NOTE: standard_training has its own `path` parameter (learner/log directory) that
# shadows this name only inside that function; getMyopsDls still reads this global.
path = "/content/drive/My Drive/miccai2020/myops/png"

In [5]:
# Reproducible 5-fold split of patient ids 101..125: seed the global NumPy RNG,
# permute the ids, then lay them out as 5 rows of 5 (row k = validation ids of fold k).
np.random.seed(42)
ids = np.random.permutation(np.arange(101, 126)).reshape(-1, 5)
ids

array([[109, 117, 101, 124, 112],
       [110, 114, 102, 123, 106],
       [103, 113, 116, 104, 105],
       [121, 118, 122, 119, 125],
       [108, 111, 115, 120, 107]])

In [6]:
class AddMaskCodeMapping(Transform):
    "Remap raw mask pixel values to class indices through a lookup tensor, optionally attaching class `codes`."
    def __init__(self, mapping, codes=None):
        # mapping[v] is the class index written for raw pixel value v.
        self.mapping, self.codes = mapping, codes
        if codes is None:
            return
        # Class names and class count are only known when codes are given.
        self.vocab, self.c = codes, len(codes)

    def encodes(self, o:PILMask):
        # index_select needs integer (long) indices, hence the dtype cast first.
        pixels = ToTensor()(o).to(dtype=torch.long)
        remapped = self.mapping.index_select(0, pixels.flatten()).reshape(*pixels.shape)
        return PILMask.create(remapped.to(dtype=torch.uint8))

    def decodes(self, o:TensorMask):
        # Inputs decode fine already, but get_preds outputs do not: a tensor with
        # more than two dims is collapsed to labels by argmax over dim 0
        # (presumably the class axis of a single prediction - confirm).
        o = o.argmax(dim=0) if o.ndim > 2 else o
        if self.codes is not None:
            o._meta = {'codes': self.codes}
        return o

In [7]:
def MappedMaskBlock(mapping,codes=None):
    "Segmentation-mask `TransformBlock` whose pixel values are remapped to classes by `mapping`, optionally named by `codes`."
    remap = AddMaskCodeMapping(mapping=mapping, codes=codes)
    return TransformBlock(
        type_tfms=PILMask.create,
        item_tfms=remap,
        batch_tfms=IntToFloatTensor,
    )

In [8]:
def getMappedMaskBlock(predefined_mapping_name):
    """Return a `MappedMaskBlock` for one of the predefined label schemes.

    Each scheme is ``(mapping, codes)``: ``mapping[v]`` is the class index
    assigned to raw mask pixel value ``v`` (one entry per value 0-5, so masks
    must not contain larger values), and ``codes`` names the resulting classes.

    Raises:
        KeyError: if ``predefined_mapping_name`` is not a known scheme; the
            message lists the valid scheme names.
    """
    predefined_mappings = {
        'full': ([0,1,2,3,4,5],['bg','lv','my','rv','ed','sc']),
        'edOnly': ([0,0,0,0,1,0],['bg','ed']),
        'edScCombined': ([0,0,0,0,1,1],['bg','edSc']),
        'scOnly': ([0,0,0,0,0,1],['bg','sc']),
        'edScOnly': ([0,0,0,0,1,2],['bg','ed','sc']),
    }
    try:
        mapping, codes = predefined_mappings[predefined_mapping_name]
    except KeyError:
        # Still a KeyError (callers may catch it), but say what would have worked.
        valid = ", ".join(sorted(predefined_mappings))
        raise KeyError(
            f"Unknown mapping {predefined_mapping_name!r}; expected one of: {valid}"
        ) from None
    return MappedMaskBlock(mapping=torch.LongTensor(mapping), codes=codes)

In [9]:
def getMyopsDls(val_ids, mapping_name="full", images="images", bs=12, num_workers=4):
    """Build fastai `DataLoaders` for one cross-validation fold.

    Args:
        val_ids: patient ids held out for validation. A file goes to the
            validation split when the integer before the first "-" in its
            name is in ``val_ids``. Any iterable of ints works (e.g. a row of
            the module-level ``ids`` array, a list, or a generator).
        mapping_name: label scheme name passed to ``getMappedMaskBlock``.
        images: sub-directory of the module-level ``path`` holding the input
            images; each mask path is derived by replacing ``images`` with
            "masks" in the image path.
        bs: batch size, applied to the training and validation loaders.
        num_workers: number of data-loading workers.

    Returns:
        The fastai `DataLoaders` built from ``f'{path}/{images}'``.
    """
    # Materialize once as plain ints: a one-shot iterable would otherwise be
    # exhausted by the first splitter call and silently corrupt the split.
    val_set = {int(i) for i in val_ids}
    mmb = getMappedMaskBlock(mapping_name)
    myopsData = DataBlock(
        blocks=(ImageBlock, mmb),
        get_items=get_image_files,
        splitter=FuncSplitter(lambda o: int(o.name.split("-")[0]) in val_set),
        # NOTE: str.replace rewrites *every* occurrence of `images` in the full
        # path (directories and file name alike), not just the directory part.
        get_y=lambda o: str(o).replace(images, "masks"),
        item_tfms=CropPad(256),
        batch_tfms=aug_transforms(max_rotate=90, pad_mode="zeros"))
    dls = myopsData.dataloaders(f'{path}/{images}', num_workers=num_workers, batch_size=bs)
    # Pin the validation loader's batch size explicitly as well.
    dls[1].bs = bs
    return dls

In [10]:
def multi_dice(input:Tensor, targs:Tensor, class_id=0, inverse=False)->Tensor:
    """Mean per-sample Dice score of a single class, computed from predictions.

    Args:
        input: predictions whose class axis is dim 1; argmax over it yields
            the predicted label map (e.g. (n, n_classes, H, W) logits).
        targs: integer label maps with ``n`` samples and the same number of
            elements per sample as the argmaxed ``input`` (e.g. (n, H, W)).
            Non-contiguous tensors are accepted.
        class_id: label treated as foreground for the binary Dice.
        inverse: if True, score the complement (every label except ``class_id``).

    Returns:
        0-dim tensor: Dice averaged over the batch. A sample in which neither
        prediction nor target contains the foreground scores 1.
    """
    n = targs.shape[0]
    # reshape (not view): view raises on non-contiguous tensors (sliced/transposed).
    pred = input.argmax(dim=1).reshape(n, -1)
    # Binarize: 1 where a pixel carries class_id, 0 elsewhere.
    output = (pred == class_id).float()
    targs = (targs.reshape(n, -1) == class_id).float()
    if inverse:
        output = 1 - output
        targs = 1 - targs
    intersect = (output * targs).sum(dim=1)
    # |A| + |B| (the Dice denominator), not the set union despite the name.
    union = (output + targs).sum(dim=1)
    res = 2. * intersect / union
    # 0/0 happens only when both masks are empty: treat as perfect agreement.
    res[torch.isnan(res)] = 1
    return res.mean()

def diceFG(input, targs):
    "Dice of class 1: the single foreground class of the two-class mappings (e.g. 'scOnly', 'edOnly')."
    return multi_dice(input, targs, class_id=1)


def diceLV(input, targs):
    "Dice of class 1 ('lv' in the 'full' mapping)."
    return multi_dice(input, targs, class_id=1)


def diceMY(input, targs):
    "Dice of class 2 ('my' in the 'full' mapping)."
    return multi_dice(input, targs, class_id=2)


def diceRV(input, targs):
    "Dice of class 3 ('rv' in the 'full' mapping)."
    return multi_dice(input, targs, class_id=3)


def diceEd(input, targs):
    "Dice of class 4 ('ed' in the 'full' mapping)."
    return multi_dice(input, targs, class_id=4)


def diceSc(input, targs):
    "Dice of class 5 ('sc' in the 'full' mapping)."
    return multi_dice(input, targs, class_id=5)


# Per-class metrics for the 6-class 'full' mapping (background excluded).
dices = [
    diceLV,
    diceMY,
    diceRV,
    diceEd,
    diceSc,
]

In [11]:
def myFocal(weights=(.2, .8)):
    """Build a loss callable from MONAI's `FocalLoss` with per-class `weights`.

    ``weights`` is converted with ``torch.Tensor`` and passed as MONAI's
    ``weight``; the loss is averaged (``reduction='mean'``). The default is an
    immutable tuple so a shared default object can never be mutated between calls.

    Returns:
        ``focal_loss(input, target)``, which inserts a singleton axis at dim 1
        of ``target`` before calling the MONAI loss - presumably turning
        (B, H, W) class-index masks into the (B, 1, H, W) layout MONAI expects;
        confirm against the installed MONAI version.
    """
    monaiFocal = FocalLoss(weight=torch.Tensor(weights), reduction='mean')

    def focal_loss(input, target):
        return monaiFocal(input, target.unsqueeze(1))

    return focal_loss

# Generic training function

In [12]:
def standard_training(
    mapping="full",
    images="images",
    cleanup=True,
    metrics=None,
    weights=(1.0 / 6,) * 6,
    cv=0,
    name="standard_training",
    use_focal_loss=True,
    path='/content/drive/My Drive/miccai2020/myops',
    save=False,
    epochs=20,
    freeze_epochs=10,
    base_lr=1e-3,
    ):
    """Train a ResNet-34 U-Net on one cross-validation fold and log metrics to CSV.

    Args:
        mapping: label scheme name forwarded to ``getMyopsDls``.
        images: image sub-directory forwarded to ``getMyopsDls`` (e.g. "images", "LGE", "T2").
        cleanup: if True, delete the learner, run ``gc.collect()`` and
            ``torch.cuda.empty_cache()``, and return None; if False, return the learner.
        metrics: metrics logged after ``foreground_acc``; None means ``dices``.
        weights: per-class weights for ``myFocal`` (only used when
            ``use_focal_loss``). The default of six equal weights presumably
            only fits the 6-class 'full' mapping - pass one weight per class otherwise.
        cv: fold index; validation patients are the module-level ``ids[cv]``.
        name: run name, used for the CSV log file and for ``learn.save``.
        use_focal_loss: if False, ``loss_func=None`` is passed to ``unet_learner``.
        path: learner directory (distinct from the module-level data ``path``,
            which ``getMyopsDls`` reads); logs go to ``f'{path}/logs/{name}.csv'``.
        save: if True, ``learn.save(name)`` after training.
        epochs, freeze_epochs, base_lr: forwarded to ``learn.fine_tune``; the
            defaults reproduce the previous hard-coded 20 / 10 / 1e-3.
    """
    if metrics is None:
        metrics = dices
    # loss_func=None presumably lets fastai pick its default loss for the data - confirm.
    my_loss = myFocal(weights) if use_focal_loss else None
    learn = unet_learner(
        getMyopsDls(ids[cv], mapping, images),
        resnet34,
        path=path,
        loss_func=my_loss,
        metrics=[foreground_acc, *metrics],
        cbs=[CSVLogger(f'{path}/logs/{name}.csv', append=True)],
    )
    learn.fine_tune(epochs, freeze_epochs=freeze_epochs, base_lr=base_lr)
    if save:
        learn.save(name)
    if not cleanup:
        return learn
    # Drop our reference and release cached CUDA memory so back-to-back runs
    # (e.g. the experiment loop) do not accumulate GPU memory.
    del learn
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# (run-name prefix, extra keyword arguments for standard_training). Every entry
# is trained once per fold, in this order, under the run name f"{prefix}_cv{cv}".
experiments = [
    ("multi_ce", dict(use_focal_loss=False)),
    ("multi_balanced", dict()),
    ("multi_pathoMyo.2", dict(weights=[.4/3, .4/3, .2, .4/3, .2, .2])),
    ("multi_pathoMyo.3", dict(weights=[.1/3, .1/3, .3, .1/3, .3, .3])),
    ("multi_patho.2", dict(weights=[.6/4] * 4 + [.2, .2])),
    ("multi_patho.35", dict(weights=[.3/4] * 4 + [.35, .35])),
    ("multi_patho.49", dict(weights=[.02/4] * 4 + [.49, .49])),
    ("multi_scar.2", dict(weights=[.8/5] * 5 + [.2])),
    ("multi_scar.4", dict(weights=[.6/5] * 5 + [.4])),
    ("multi_scar.6", dict(weights=[.4/5] * 5 + [.6])),
    ("multi_scar.8", dict(weights=[.2/5] * 5 + [.8])),
    ("multi_scar.99", dict(weights=[.01/5] * 5 + [.99])),
    ("multi_edema.2", dict(weights=[.8/5] * 4 + [.2, .8/5])),
    ("multi_edema.4", dict(weights=[.6/5] * 4 + [.4, .6/5])),
    ("multi_edema.6", dict(weights=[.4/5] * 4 + [.6, .4/5])),
    ("multi_edema.8", dict(weights=[.2/5] * 4 + [.8, .2/5])),
    ("multi_edema.99", dict(weights=[.01/5] * 4 + [.99, .01/5])),
    ("lge_scarOnly.5", dict(mapping="scOnly", images="LGE", metrics=[diceFG], weights=[0.5, 0.5])),
    ("lge_scarOnly.8", dict(mapping="scOnly", images="LGE", metrics=[diceFG], weights=[0.2, 0.8])),
    ("t2_edemaOnly.5", dict(mapping="edOnly", images="T2", metrics=[diceFG], weights=[0.5, 0.5])),
    ("t2_edemaOnly.8", dict(mapping="edOnly", images="T2", metrics=[diceFG], weights=[0.2, 0.8])),
]

for cv in range(5):
    for prefix, extra in experiments:
        standard_training(name=f"{prefix}_cv{cv}", cv=cv, **extra)

epoch,train_loss,valid_loss,foreground_acc,diceLV,diceMY,diceRV,diceEd,diceSc,time
0,0.718315,0.98065,0.006446,0.029914,0.0,0.004842,0.2,0.12,00:21


epoch,train_loss,valid_loss,foreground_acc,diceLV,diceMY,diceRV,diceEd,diceSc,time
0,0.436957,0.254848,0.292792,0.626406,0.036401,0.329156,0.2,0.12,00:17
1,0.332487,0.199177,0.595028,0.68956,0.57215,0.5198,0.2,0.122796,00:16


epoch,train_loss,valid_loss,foreground_acc,diceLV,diceMY,diceRV,diceEd,diceSc,time
