<a href="https://colab.research.google.com/github/benardt/ML/blob/main/mylib.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy
import matplotlib as mpl
import matplotlib.pyplot as plt
import torchvision
from skimage.util import compare_images
import torch.nn as nn
from prettytable import PrettyTable

In [None]:
def mytest(message):
    """Print a one-line smoke-test marker followed by *message*.

    Apparently used to check that importing this notebook/module worked;
    it has no other effect and returns None.
    """
    marker = "test import..."
    print(marker, message)

In [None]:
def images_show(x, y, zoom, showaxis='on'):
    """Show two image batches blended into a single plot.

    Each batch is tiled with ``torchvision.utils.make_grid``, moved to a
    channels-last numpy array and inverted (``1 - v``). Pixels of the
    inverted ``y`` grid that are exactly black are recoloured red, and the
    two grids are combined with
    ``skimage.util.compare_images(..., method='blend')``.

    Args:
        x: First image batch (NOTE(review): presumably values in [0, 1]
            given the ``1 - v`` inversion -- confirm with callers).
        y: Second image batch; where its inverted grid is pure black it is
            drawn in red.
        zoom: Width and height of the square figure, in inches.
        showaxis: Pass ``'off'`` to hide the axes; any other value keeps them.
    """
    def _inverted_grid(batch):
        # Tile the batch, reorder (C, H, W) -> (H, W, C) for plotting, invert.
        return 1 - torchvision.utils.make_grid(batch).numpy().transpose((1, 2, 0))

    base = _inverted_grid(x)
    highlight = _inverted_grid(y)
    is_black = numpy.all(highlight == [0, 0, 0], axis=-1)
    highlight[is_black] = (1, 0, 0)
    blended = compare_images(highlight, base, method='blend')

    plt.figure(figsize=(zoom, zoom))
    plt.imshow(blended)
    if showaxis == 'off':
        plt.axis('off')
    # Turns pyplot interactive mode off (a global setting) before showing.
    plt.ioff()
    plt.show()

In [None]:
def tensShow(T1, T2, mytitle, idx=0):
    """Plot one sample of two tensor batches in three side-by-side panels.

    Element ``idx`` of each batch is detached, moved to the CPU, converted to
    numpy and squeezed of singleton dimensions. The panels are:

    * panel 0: ``T2[idx]`` with the reversed ``gist_rainbow`` colormap;
    * panel 1: the binary mask ``T2[idx] > 0`` in greys;
    * panel 2: the inverted ``T1[idx]`` (``1 - T1[idx]``) in reversed greys,
      overlaid with ``T2[idx]`` in ``jet`` at 30% opacity.

    Args:
        T1: Batch used as the background image (NOTE(review): presumably
            values in [0, 1] given the ``1 - x`` inversion -- confirm).
        T2: Batch to visualise, indexed the same way as ``T1``.
        mytitle: Figure super-title.
        idx: Index of the batch element to show. Defaults to 0, which matches
            the previous hard-coded behaviour.
    """
    background = 1 - numpy.squeeze(T1[idx].cpu().detach().numpy())
    heatmap = numpy.squeeze(T2[idx].cpu().detach().numpy())
    mask = 1 * (heatmap > 0.0)

    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(16, 8))
    ax[0].imshow(heatmap, cmap=mpl.cm.gist_rainbow_r, interpolation="bicubic")
    ax[1].imshow(mask, cmap='Greys', interpolation="bicubic")
    ax[2].imshow(background, cmap='Greys_r', interpolation="bicubic")
    ax[2].imshow(heatmap, cmap='jet', interpolation="bicubic", alpha=0.3)

    # Add the title before tight_layout so the layout can make room for it
    # (matplotlib >= 3.3 accounts for an existing suptitle). Calling
    # tight_layout first lets the title overlap the panels.
    fig.suptitle(mytitle)
    fig.tight_layout()
    plt.show()

In [None]:
def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Both parts are truncated toward zero with ``int``.

    Args:
        start_time: Start timestamp, in seconds.
        end_time: End timestamp, in seconds.

    Returns:
        A ``(minutes, seconds)`` tuple of ints.
    """
    span = end_time - start_time
    minutes = int(span / 60)
    leftover = span - minutes * 60
    return minutes, int(leftover)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2*sum(x*y) + smooth) / (sum(x) + sum(y) + smooth)``.

    Both tensors are flattened, so the score is computed over all elements of
    the batch at once rather than per sample. No activation is applied:
    ``inputs`` are expected to be already activated, e.g. by a sigmoid layer
    at the end of the model.

    Args:
        weight: Accepted for interface compatibility; currently ignored.
        size_average: Accepted for interface compatibility; currently ignored.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the Dice loss between ``inputs`` and ``targets``.

        Args:
            inputs: Predicted values, any shape.
            targets: Ground truth with the same number of elements as
                ``inputs``.
            smooth: Additive smoothing in the numerator and denominator. With
                the default of 1 it keeps the ratio defined when both tensors
                sum to zero.

        Returns:
            A scalar tensor equal to ``1 - dice``.
        """
        # reshape (not view) so that non-contiguous tensors, such as transposed
        # or permuted model outputs, are accepted; it copies only if needed.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice

In [None]:
def count_parameters(model):
    """Print a per-parameter table of trainable element counts and return the total.

    Only parameters with ``requires_grad`` set are listed. Each row is a name
    from ``model.named_parameters()`` and that tensor's ``numel()``.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a torch module).

    Returns:
        The total number of trainable parameter elements, as an int.
    """
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])
    print(table)

    total = sum(count for _, count in trainable)
    print(f"Total Trainable Params: {total}")
    return total