# <font color="blue">**CS685  : Project**</font> 

### Active Learning for optimal sample selection

In [None]:
## Mount Google Drive so the data directory under /content/drive can be read.
## The two commented lines below are the earlier variant without a forced remount.
# from google.colab import drive
# drive.mount('/content/drive')
from google.colab import drive
# force_remount=True is passed explicitly; presumably this re-mounts even when
# the drive is already mounted in the session -- confirm against the Colab docs.
drive.mount('/content/drive', force_remount=True) 

In [None]:
import tensorflow as tf
# Bare expression: in a notebook cell its value is shown as the cell output.
# Presumably reports the GPU device name visible to TensorFlow -- confirm.
tf.test.gpu_device_name()

In [None]:
from tensorflow.python.client import device_lib
# Bare expression: in a notebook cell the returned device list is displayed.
device_lib.list_local_devices()

## Utils

In [None]:
from __future__ import print_function
import math
import os
import random
import copy
import scipy
import imageio
import string
import numpy as np
from datetime import datetime
import pytz
from skimage.transform import resize
try:  # SciPy >= 0.19
    from scipy.special import comb
except ImportError:
    from scipy.misc import comb

def bernstein_poly(i, n, t):
    """
     Evaluate one Bernstein basis term of degree n at t.

     The value is comb(n, i) * t**(n - i) * (1 - t)**i. Note the exponent
     placement: t carries the power (n - i) and (1 - t) carries the power i.
     t may be a scalar or a NumPy array (the result is then element-wise).
    """
    coefficient = comb(n, i)
    t_term = t ** (n - i)
    complement_term = (1 - t) ** i
    # Multiply left to right in the same order as the closed form above.
    return coefficient * t_term * complement_term

def bezier_curve(points, nTimes=1000):
    """
       Sample the Bezier curve defined by a list of 2-D control points.

       points: list of lists or tuples such as
               [ [1,1], [2,3], [4,5], ..[Xn, Yn] ]; only the first two
               entries of each point are read.
       nTimes: number of parameter samples taken evenly over [0, 1]
               (defaults to 1000).

       Returns a pair (xvals, yvals) of arrays with nTimes entries each.
       The direction in which the curve is traversed follows the exponent
       convention used by bernstein_poly.

       See http://processingjs.nihongoresources.com/bezierinfo/
    """
    control_count = len(points)
    control_x = np.array([pt[0] for pt in points])
    control_y = np.array([pt[1] for pt in points])

    samples = np.linspace(0.0, 1.0, nTimes)

    # One row per control point: that point's basis weight at every sample.
    basis_rows = []
    for k in range(0, control_count):
        basis_rows.append(bernstein_poly(k, control_count - 1, samples))
    basis = np.array(basis_rows)

    return np.dot(control_x, basis), np.dot(control_y, basis)

def data_augmentation(x, y, prob=0.5):
    """Flip x and y together along randomly chosen axes.

    Up to three flips are applied. Before each one a random draw is compared
    with ``prob``; the first draw that is not below ``prob`` stops the loop.
    Each flip picks one axis from {0, 1, 2} and applies it to both arrays, so
    x and y stay aligned. Returns the (possibly flipped) pair ``(x, y)``.
    """
    flips_left = 3
    while True:
        # The draw happens on every loop test, even once no flips remain,
        # so the sequence of random numbers consumed stays the same.
        keep_flipping = random.random() < prob
        if not keep_flipping or flips_left <= 0:
            break
        axis = random.choice([0, 1, 2])
        x, y = np.flip(x, axis=axis), np.flip(y, axis=axis)
        flips_left -= 1

    return x, y

def nonlinear_transformation(x, prob=0.5):
    """Remap the values of x through a random Bezier-derived lookup curve.

    With probability ``1 - prob`` the input is returned unchanged (the same
    object). Otherwise a curve is built from the control points (0, 0), two
    random points in the unit square, and (1, 1), and x is passed through
    ``np.interp`` using the sampled curve as the lookup table.

    Args:
        x: array of values to remap; passed straight to ``np.interp``.
        prob: probability of applying the transformation.

    Returns:
        x itself when the transformation is skipped, otherwise a new array
        holding the interpolated values.
    """
    if random.random() >= prob:
        return x
    points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]]
    xvals, yvals = bezier_curve(points, nTimes=100000)
    if random.random() < 0.5:
        # Half chance to get flip: only the x samples are sorted, the y
        # samples keep the order in which the curve produced them.
        xvals = np.sort(xvals)
    else:
        # Otherwise both sample arrays are sorted independently.
        xvals, yvals = np.sort(xvals), np.sort(yvals)
    return np.interp(x, xvals, yvals)

def local_pixel_shuffling(x, prob=0.5):
    """Shuffle pixels inside many small random windows of channel 0.

    With probability ``1 - prob`` the input is returned unchanged (the same
    object). Otherwise 10000 windows are drawn; each is at most a tenth of
    the image height/width, its pixels are read from an untouched copy of x,
    shuffled, and written into the result at the same location. Only index 0
    of the first axis is modified; x itself is never written to.

    x is unpacked as a 3-D array ``(_, rows, cols)``.
    """
    if random.random() >= prob:
        return x

    shuffled = copy.deepcopy(x)
    # Windows are always read from this pristine copy, not from `shuffled`.
    pristine = copy.deepcopy(x)
    _, height, width = x.shape

    for _ in range(10000):
        win_h = random.randint(1, height // 10)
        win_w = random.randint(1, width // 10)
        top = random.randint(0, height - win_h)
        left = random.randint(0, width - win_w)
        row_span = slice(top, top + win_h)
        col_span = slice(left, left + win_w)

        # flatten() copies, so shuffling never touches `pristine`.
        pixels = pristine[0, row_span, col_span].flatten()
        np.random.shuffle(pixels)
        shuffled[0, row_span, col_span] = pixels.reshape((win_h, win_w))

    return shuffled

def image_in_painting(x):
    """Overwrite up to five random rectangles of x with uniform noise.

    x is unpacked as ``(_, rows, cols)`` and is modified in place; the same
    object is returned. Each rectangle spans between a sixth and a third of
    the height/width, stays at least 3 pixels from the top and left edges,
    and receives one 2-D noise patch that is written across the whole first
    axis. After each patch there is a 5% chance of stopping early.
    """
    _, height, width = x.shape
    patches_left = 5
    while patches_left > 0 and random.random() < 0.95:
        patch_h = random.randint(height // 6, height // 3)
        patch_w = random.randint(width // 6, width // 3)
        top = random.randint(3, height - patch_h - 3)
        left = random.randint(3, width - patch_w - 3)
        noise = np.random.rand(patch_h, patch_w, ) * 1.0
        x[:, top:top + patch_h, left:left + patch_w] = noise
        patches_left -= 1
    return x

def image_out_painting(x):
    """Replace x with noise except for a few windows kept from the original.

    A new array of uniform noise with the shape of x is created, then between
    one and five rectangular windows are copied back from the original at
    their own locations. The first window is always copied; before each
    further one there is a 5% chance of stopping. The input array is not
    modified; the new array is returned.

    x is unpacked as ``(_, rows, cols)``.
    """
    _, height, width = x.shape
    original = copy.deepcopy(x)
    canvas = np.random.rand(x.shape[0], x.shape[1], x.shape[2], ) * 1.0

    def restore_window():
        # Window size is the image size minus roughly 3/7 to 4/7 of it.
        win_h = height - random.randint(3 * height // 7, 4 * height // 7)
        win_w = width - random.randint(3 * width // 7, 4 * width // 7)
        top = random.randint(3, height - win_h - 3)
        left = random.randint(3, width - win_w - 3)
        rows = slice(top, top + win_h)
        cols = slice(left, left + win_w)
        canvas[:, rows, cols] = original[:, rows, cols]

    restore_window()
    extra_windows = 4
    while extra_windows > 0 and random.random() < 0.95:
        restore_window()
        extra_windows -= 1
    return canvas
                


def generate_pair(img, batch_size, config, status="test"):
    """Endlessly yield (x, y) batches of distorted / original images.

    Each iteration shuffles the sample indices, takes the first
    ``batch_size`` of them, and builds ``y`` (the originals, flipped) and
    ``x`` (a copy of ``y`` that is further distorted by local pixel
    shuffling, a non-linear intensity remap, and in- or out-painting).

    Args:
        img: array indexed as ``img[list_of_indices]`` on its first axis.
        batch_size: maximum number of samples per batch. When ``img`` holds
            fewer samples the batch is simply smaller.
        config: object providing ``flip_rate``, ``local_rate``,
            ``nonlinear_rate``, ``paint_rate``, ``inpaint_rate``,
            ``save_samples`` and, when samples are saved, ``sample_path``
            and ``exp_name``.
        status: sample images are only written when this is ``"train"``.

    Yields:
        Tuples ``(x, y)`` of arrays with the same shape.
    """
    sample_count = img.shape[0]
    while True:
        index = list(range(sample_count))
        random.shuffle(index)
        y = img[index[:batch_size]]
        # x starts as an independent copy of y, so no per-item copy is needed.
        x = copy.deepcopy(y)
        # Use the real batch length: it is smaller than batch_size when img
        # has fewer samples, and it may differ from config.batch_size.
        actual_batch = y.shape[0]
        for n in range(actual_batch):

            # Flip (applied to both x and y so they stay aligned)
            x[n], y[n] = data_augmentation(x[n], y[n], config.flip_rate)

            # Local Shuffle Pixel
            x[n] = local_pixel_shuffling(x[n], prob=config.local_rate)

            # Apply non-Linear transformation with an assigned probability
            x[n] = nonlinear_transformation(x[n], config.nonlinear_rate)

            # Inpainting & Outpainting
            if random.random() < config.paint_rate:
                if random.random() < config.inpaint_rate:
                    # Inpainting
                    x[n] = image_in_painting(x[n])
                else:
                    # Outpainting
                    x[n] = image_out_painting(x[n])

        # Save sample images module: about 1% of training batches.
        if (config.save_samples is not None and status == "train"
                and actual_batch > 0 and random.random() < 0.01):
            # Pick from the batch actually built, not from config.batch_size.
            n_sample = random.choice(range(actual_batch))
            final_sample = np.concatenate((x[n_sample,0,:,:], y[n_sample,0,:,:]), axis=1)
            final_sample = final_sample * 255.0
            final_sample = final_sample.astype(np.uint8)
            file_name = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(10)])+'.'+config.save_samples
            sample_dir = os.path.join(config.sample_path, config.exp_name)
            # Create the target directory on first use so imwrite cannot fail
            # merely because nobody created it beforehand.
            os.makedirs(sample_dir, exist_ok=True)
            imageio.imwrite(os.path.join(sample_dir, file_name), final_sample)

        yield (x, y)

        


# UNET

In [None]:
""" Full assembly of the parts to form the complete network """

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use padding=1 and no bias. ``mid_channels`` is the
    width between the two convolutions; a falsy value (``None`` or 0) means
    "use ``out_channels``".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Kept under the name `double_conv` so parameter names are unchanged.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        # Kept under the name `maxpool_conv` so parameter names are unchanged.
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip tensor's size, concat, DoubleConv.

    With ``bilinear`` true the upsampling is a parameter-free bilinear
    ``nn.Upsample`` and the DoubleConv narrows through ``in_channels // 2``;
    otherwise a stride-2 transposed convolution halves the channels first.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling keeps the channel count, so the channel
            # reduction is done by the convolutions instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2`` (the skip input), and fuse them."""
        upsampled = self.up(x1)
        # Axes 2 and 3 are treated as height and width.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        pad_left = gap_w // 2
        pad_top = gap_h // 2

        # F.pad takes (left, right, top, bottom) for the last two axes.
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left,
                                      pad_top, gap_h - pad_top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


class OutConv(nn.Module):
    """Output head: a 1x1 convolution followed by a sigmoid."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        projected = self.conv(x)
        # The sigmoid means the returned values lie strictly between 0 and 1.
        return self.sigmoid(projected)



class UNet(nn.Module):
    """U-Net with four Down stages, four Up stages and an OutConv head.

    Args:
        n_channels: channel count of the input tensor.
        n_classes: channel count produced by the output head.
        bilinear: forwarded to every Up stage; when true the widest
            encoder/decoder widths are halved (``factor = 2``).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1

        # Encoder (submodules are created in the same order as before so
        # parameter ordering and seeded initialisation are unaffected).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Output head
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns the output of ``self.outc`` (OutConv in this file ends with
        a sigmoid, so despite the historical name "logits" the values are
        already squashed).
        """
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        return self.outc(decoded)

    
class UNet_hidden(nn.Module):
    """Same layer layout as UNet, but forward also returns the bottleneck.

    Args:
        n_channels: channel count of the input tensor.
        n_classes: channel count produced by the output head.
        bilinear: forwarded to every Up stage; when true the widest
            encoder/decoder widths are halved (``factor = 2``).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1

        # Encoder (creation order kept so parameter ordering is unchanged).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Output head
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return ``(output, bottleneck)``.

        ``output`` is the result of ``self.outc`` on the decoded features and
        ``bottleneck`` is the tensor produced by the deepest encoder stage
        (``self.down4``).
        """
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottleneck = self.down4(skip4)

        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        output = self.outc(decoded)
        return output, bottleneck

## Swin_UNET

In [None]:
# import torch
# import torch.nn as nn
# # from torchvision.models.utils import load_state_dict_from_url

# # Swin Transformer Encoder Block
# class SwinTransformerEncoderBlock(nn.Module):
#     def __init__(self, embed_dim, num_heads, window_size, shift_size, mlp_ratio=4.0):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attn = nn.MultiheadAttention(embed_dim, num_heads)
#         self.norm2 = nn.LayerNorm(embed_dim)
#         self.mlp = nn.Sequential(
#             nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
#             nn.GELU(),
#             nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
#         )

#         # Window partitioning and shifting
#         self.window_size = window_size
#         self.shift_size = shift_size

#     def forward(self, x):
#         # Apply layer normalization
#         x = self.norm1(x)

#         # Apply self-attention
#         x = x.permute(2, 0, 1)
#         x, _ = self.attn(x, x, x)
#         x = x.permute(1, 2, 0)

#         # Apply residual connection and layer normalization
#         x = x + x.permute(0, 2, 1)
#         x = self.norm2(x)

#         # Apply MLP
#         y = self.mlp(x)

#         # Partition into non-overlapping windows and shift
#         B, N, C = y.shape
#         h = self.window_size
#         w = N // h
#         y = y.view(B, h, w, C)
#         y = y.permute(0, 3, 1, 2)
#         y = torch.nn.functional.pad(y, (0, 0, 0, 0, self.shift_size // 2, self.shift_size // 2), mode="constant")
#         y = y.reshape(B, C, h * (w + 2 * (self.shift_size // 2)))
#         y = y.permute(0, 2, 1)

#         # Apply residual connection
#         y = y + x

#         return y

# # Swin UNet Model
# class SwinUNet(nn.Module):
#     def __init__(self, input_channels, num_classes):
#         super().__init__()

#         # Swin Transformer Encoder Blocks
#         self.enc1 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0)
#         self.enc2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0)
#         self.enc3 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0)
#         self.enc4 = SwinTransformerEncoderBlock(embed_dim=512, num_heads=16, window_size=7, shift_size=0)

#         # Swin Transformer Decoder Blocks
#         self.dec1 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0)
#         self.dec2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0)
#         self.dec3 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0)

#         # Final convolutional layer for segmentation output
#         self.final_conv = nn.Conv2d(96, num_classes, kernel_size=1, stride=1, padding=0)

#     def forward(self, x):
#         # Encoder
#         x1 = self.enc1(x)
#         x2 = self.enc2(x1)
#         x3 = self.enc3(x2)
#         x4 = self.enc4(x3)

#         # Decoder
#         y1 = self.dec1(x4) + nn.functional.interpolate(x4, scale_factor=2, mode="nearest")
#         y2 = self.dec2(y1) + nn.functional.interpolate(y1, scale_factor=2, mode="nearest")
#         y3 = self.dec3(y2) + nn.functional.interpolate(y2, scale_factor=2, mode="nearest")

#         # Output
#         y = self.final_conv(y3)
#         return y


In [None]:
import torch
import torch.nn as nn
# from torchvision.models.utils import load_state_dict_from_url

# Swin Transformer Encoder Block
class SwinTransformerEncoderBlock(nn.Module):
    """Block made of LayerNorm, multi-head attention, LayerNorm and an MLP.

    Args:
        embed_dim: width passed to both LayerNorms, the attention layer and
            the MLP input/output.
        num_heads: number of heads passed to nn.MultiheadAttention.
        window_size: stored and used by forward() in its reshape step.
        shift_size: stored and used by forward() as a padding amount.
        mlp_ratio: the MLP hidden width is int(embed_dim * mlp_ratio).

    NOTE(review): as written, forward() looks unlikely to run end to end;
    see the review notes on the individual lines below. Confirm before use.
    """

    def __init__(self, embed_dim, num_heads, window_size, shift_size, mlp_ratio=4.0):
        super().__init__()
        # NOTE(review): the second positional parameter of nn.LayerNorm is
        # `eps` according to the PyTorch documentation, so the tuple
        # (256, 256) is passed as eps here, not as part of the normalized
        # shape -- confirm what was intended.
        self.norm1 = nn.LayerNorm(embed_dim,(256, 256))
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        # NOTE(review): same (256, 256)-as-eps concern as norm1.
        self.norm2 = nn.LayerNorm(embed_dim,(256, 256))
        # Two linear layers with a GELU in between; hidden width is
        # int(embed_dim * mlp_ratio).
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
        )

        # Window partitioning and shifting
        self.window_size = window_size
        self.shift_size = shift_size

    def forward(self, x):
        # x is permuted with four axes below, so a 4-D input is expected;
        # presumably (batch, channels, height, width) -- TODO confirm.

        # Apply layer normalization
        x = self.norm1(x)

        # Apply self-attention
        # Move axis 1 to the end, then flatten the first three axes together.
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        # NOTE(review): after this reshape the last axis has size 1, while
        # self.attn was built with embed_dim -- verify this matches what
        # nn.MultiheadAttention expects.
        x = x.reshape(B * H * W, C, 1)
        x, _ = self.attn(x, x, x)
        x = x.reshape(B, H, W, C)
        x = x.permute(0, 3, 1, 2)

        # Apply residual connection and layer normalization
        # NOTE(review): this adds the attention output to a permuted view of
        # itself (not to the block input); the two operands have their axes
        # in different orders, so the sum only broadcasts for compatible
        # sizes -- confirm intent.
        x = x + x.permute(0, 2, 3, 1)
        x = self.norm2(x)

        # Apply MLP
        y = self.mlp(x)

        # Partition into non-overlapping windows and shift
        # NOTE(review): x was 4-D after the permute above, and the layers in
        # between do not appear to change rank, so unpacking three sizes here
        # looks like it would fail -- TODO confirm.
        B, N, C = y.shape
        h = self.window_size
        w = N // h
        y = y.view(B, h, w, C)
        y = y.permute(0, 3, 1, 2)
        # Pads by shift_size // 2 on both sides of one axis (constant mode).
        y = torch.nn.functional.pad(y, (0, 0, 0, 0, self.shift_size // 2, self.shift_size // 2), mode="constant")
        y = y.reshape(B, C, h * (w + 2 * (self.shift_size // 2)))
        y = y.permute(0, 2, 1)

        # Apply residual connection
        y = y + x

        return y

# Swin UNet Model
class SwinUNet(nn.Module):
    """Encoder/decoder stack built from SwinTransformerEncoderBlock.

    Args:
        input_channels: channel count accepted by the first convolution.
        num_classes: channel count produced by the final 1x1 convolution.

    NOTE(review): the channel widths of consecutive stages do not obviously
    line up (see the notes below); confirm the model runs before relying on it.
    """

    def __init__(self, input_channels, num_classes):
        super().__init__()

        # Swin Transformer Encoder Blocks
        # embed_dim goes 64 -> 128 -> 256 -> 512.
        # NOTE(review): no layer that changes the channel count between these
        # blocks is visible here -- verify how widths are meant to change.
        self.enc1 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.enc2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.enc3 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.enc4 = SwinTransformerEncoderBlock(embed_dim=512, num_heads=16, window_size=7, shift_size=0,mlp_ratio=4.0)

        # Swin Transformer Decoder Blocks
        # The decoder reuses the encoder block class with embed_dim 256 -> 128 -> 64.
        self.dec1 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.dec2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.dec3 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0,mlp_ratio=4.0)

        # Final convolutional layer for segmentation output
        # NOTE(review): this expects 96 input channels while the last decoder
        # block is built with embed_dim=64 -- confirm which is intended.
        self.final_conv = nn.Conv2d(96, num_classes, kernel_size=1, stride=1, padding=0)

        # Modify the input channel of the first convolutional layer
        # Maps input_channels to 64 channels with a 3x3 convolution (padding=1).
        self.conv_input = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Encoder
        x = self.conv_input(x)
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)

        # Decoder
        # Each step adds the decoder block's output to a 2x nearest-neighbour
        # upsampling of that block's own input.
        # NOTE(review): the upsampled operand is twice the spatial size of
        # the tensor it was made from -- verify the two operands can be added.
        y1 = self.dec1(x4) + nn.functional.interpolate(x4, scale_factor=2, mode="nearest")
        y2 = self.dec2(y1) + nn.functional.interpolate(y1, scale_factor=2, mode="nearest")
        y3 = self.dec3(y2) + nn.functional.interpolate(y2, scale_factor=2, mode="nearest")

        # Output
        y = self.final_conv(y3)
        return y


## Config_clusters

In [None]:
import os
import shutil
from datetime import datetime
import pytz

class models_genesis_config:
    """Configuration values for pre-training, defined as class attributes.

    Evaluating this class body has side effects: it prints the model, log
    and snapshot paths and creates those directories (plus the sample-image
    directory) under a time-stamped folder inside ``model_dir``.
    """

    model = "Unet2D"
    suffix = "genesis_chest_ct"
    exp_name = model + "-" + suffix
    
    # data
    data = "/mnt/dataset/shared/zongwei/LUNA16/Self_Learning_Cubes" # not use
    scale = 32
    input_rows = 256
    input_cols = 256
    input_deps = 1
    nb_class = 1

    # image deformation
    # These rates are the probabilities read by generate_pair; outpaint_rate
    # is only used here to derive inpaint_rate.
    nonlinear_rate = 0.9
    paint_rate = 0.9
    outpaint_rate = 0.8
    inpaint_rate = 1.0 - outpaint_rate
    local_rate = 0.5
    flip_rate = 0.4
    
    # logs
    # model_dir = "../SSLModel/Reuslts/pretrained_weights"
    model_dir = "/content/drive/MyDrive/Spring_research_2023/SSLModel/Reuslts/pretrained_weights"
    # Every evaluation of this class body gets its own time-stamped folder.
    timenow = datetime.strftime(datetime.now(pytz.timezone('Asia/Singapore')), '%Y-%m-%d_%H-%M-%S')
    model_path = os.path.join(model_dir,timenow)
    print('Model path: ',model_path)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_path, exist_ok=True)
        
    logs_path = os.path.join(model_path, "Logs")
    print('log path: ',logs_path)
    os.makedirs(logs_path, exist_ok=True)
        
    shotdir = os.path.join(model_path, 'snapshot')
    print('snapshot path: ',shotdir)
    os.makedirs(shotdir, exist_ok=True)

    # Sample images: generate_pair writes to
    # os.path.join(config.sample_path, config.exp_name, <file>), so the
    # attribute must exist and the directory is created up front.
    sample_path = os.path.join(model_path, "pair_samples")
    os.makedirs(os.path.join(sample_path, exp_name), exist_ok=True)
    
    # model pre-training
    verbose = 1
    weights = os.path.join(model_path,'ISIC_Unsup.pt')
    batch_size = 1
    optimizer = "sgd"
    workers = 10
    max_queue_size = workers * 4
    save_samples = "png"
    nb_epoch = 10000
    patience = 100
    lr = 0.01
    
    def display(self):
        """Display Configuration values.

        Prints every attribute whose name does not start with a double
        underscore and whose value is not callable.
        """
        print("\nConfigurations:")
        for a in dir(self):
            if not a.startswith("__") and not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")


## Data Loader

In [None]:
from torch.utils.data import Dataset
import os
import numpy as np
import re
import imgaug.augmenters as iaa
#from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from torchvision import transforms
import cv2
import random
from tqdm import tqdm

import matplotlib.pyplot as plt

random.seed(1)


def Dataset_Loader(path, img_size):
    """Load an image array from ``path`` and resize every image.

    Args:
        path: file passed to ``np.load``.
        img_size: sequence whose first two entries are (rows, cols) of the
            resized images.

    Returns:
        A float array of shape ``(N, loaded.shape[-1], rows, cols)`` whose
        ``[:, 0]`` slice holds the resized images divided by 255.
    """
    print('\nLoading dataset...\n')
    loaded = np.load(path)
    rows, cols = img_size[0], img_size[1]
    image_count = loaded.shape[0]

    # NOTE(review): the second axis is sized from loaded.shape[-1], but the
    # loop below reads loaded[i, 0] and fills only index 0 of that axis. The
    # array is allocated uninitialised, so any other index keeps arbitrary
    # memory contents -- confirm the layout of the stored data.
    images = np.ndarray((image_count, loaded.shape[-1], rows, cols), dtype=float)
    for idx in range(image_count):
        # cv2.resize takes the target size as (width, height).
        resized = cv2.resize(loaded[idx, 0], (cols, rows), interpolation=cv2.INTER_CUBIC)
        images[idx, 0, :, :] = resized/255.
    return images

# Script entry: load the training images when this cell/module is run directly.
if __name__ == "__main__":
    data_path_train = '/content/drive/MyDrive/Spring_research_2023/data/GrayData'
    trainpath = data_path_train + '/imgs_train.npy'
    
    # Images are resized to 256x256; the result is kept in the module-level
    # name `dataset`.
    dataset = Dataset_Loader(trainpath,[256,256])

# PART 2: Clustering

## Soft_dtw_cuda

In [None]:
# # Code by Maghoumi/pytorch-softdtw-cuda
# # https://github.com/Maghoumi/pytorch-softdtw-cuda

# # MIT License
# #
# # Copyright (c) 2020 Mehran Maghoumi
# #
# # Permission is hereby granted, free of charge, to any person obtaining a copy
# # of this software and associated documentation files (the "Software"), to deal
# # in the Software without restriction, including without limitation the rights
# # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# # copies of the Software, and to permit persons to whom the Software is
# # furnished to do so, subject to the following conditions:
# #
# # The above copyright notice and this permission notice shall be included in all
# # copies or substantial portions of the Software.
# #
# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# # SOFTWARE.
# # ----------------------------------------------------------------------------------------------------------------------

# import numpy as np
# import torch
# import torch.cuda
# from numba import jit
# from torch.autograd import Function
# from numba import cuda
# import math

# # ----------------------------------------------------------------------------------------------------------------------
# @cuda.jit
# def compute_softdtw_cuda(D, gamma, bandwidth, max_i, max_j, n_passes, R):
#     """
#     :param seq_len: The length of the sequence (both inputs are assumed to be of the same size)
#     :param n_passes: 2 * seq_len - 1 (The number of anti-diagonals)
#     """
#     # Each block processes one pair of examples
#     b = cuda.blockIdx.x
#     # We have as many threads as seq_len, because the most number of threads we need
#     # is equal to the number of elements on the largest anti-diagonal
#     tid = cuda.threadIdx.x

#     # Compute I, J, the indices from [0, seq_len)

#     # The row index is always the same as tid
#     I = tid

#     inv_gamma = 1.0 / gamma

#     # Go over each anti-diagonal. Only process threads that fall on the current on the anti-diagonal
#     for p in range(n_passes):

#         # The index is actually 'p - tid' but need to force it in-bounds
#         J = max(0, min(p - tid, max_j - 1))

#         # For simplicity, we define i, j which start from 1 (offset from I, J)
#         i = I + 1
#         j = J + 1

#         # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds
#         if I + J == p and (I < max_i and J < max_j):
#             # Don't compute if outside bandwidth
#             if not (abs(i - j) > bandwidth > 0):
#                 r0 = -R[b, i - 1, j - 1] * inv_gamma
#                 r1 = -R[b, i - 1, j] * inv_gamma
#                 r2 = -R[b, i, j - 1] * inv_gamma
#                 rmax = max(max(r0, r1), r2)
#                 rsum = math.exp(r0 - rmax) + math.exp(r1 - rmax) + math.exp(r2 - rmax)
#                 softmin = -gamma * (math.log(rsum) + rmax)
#                 R[b, i, j] = D[b, i - 1, j - 1] + softmin

#         # Wait for other threads in this block
#         cuda.syncthreads()

# # ----------------------------------------------------------------------------------------------------------------------
# @cuda.jit
# def compute_softdtw_backward_cuda(D, R, inv_gamma, bandwidth, max_i, max_j, n_passes, E):
#     k = cuda.blockIdx.x
#     tid = cuda.threadIdx.x

#     # Indexing logic is the same as above, however, the anti-diagonal needs to
#     # progress backwards
#     I = tid

#     for p in range(n_passes):
#         # Reverse the order to make the loop go backward
#         rev_p = n_passes - p - 1

#         # convert tid to I, J, then i, j
#         J = max(0, min(rev_p - tid, max_j - 1))

#         i = I + 1
#         j = J + 1

#         # Only compute if element[i, j] is on the current anti-diagonal, and also is within bounds
#         if I + J == rev_p and (I < max_i and J < max_j):

#             if math.isinf(R[k, i, j]):
#                 R[k, i, j] = -math.inf

#             # Don't compute if outside bandwidth
#             if not (abs(i - j) > bandwidth > 0):
#                 a = math.exp((R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) * inv_gamma)
#                 b = math.exp((R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) * inv_gamma)
#                 c = math.exp((R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) * inv_gamma)
#                 E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c

#         # Wait for other threads in this block
#         cuda.syncthreads()

# # ----------------------------------------------------------------------------------------------------------------------
# class _SoftDTWCUDA(Function):
#     """
#     CUDA implementation is inspired by the diagonal one proposed in https://ieeexplore.ieee.org/document/8400444:
#     "Developing a pattern discovery method in time series data and its GPU acceleration"
#     """

#     @staticmethod
#     def forward(ctx, D, gamma, bandwidth):
#         dev = D.device
#         dtype = D.dtype
#         gamma = torch.cuda.FloatTensor([gamma])
#         bandwidth = torch.cuda.FloatTensor([bandwidth])

#         B = D.shape[0]
#         N = D.shape[1]
#         M = D.shape[2]
#         threads_per_block = max(N, M)
#         n_passes = 2 * threads_per_block - 1

#         # Prepare the output array
#         R = torch.ones((B, N + 2, M + 2), device=dev, dtype=dtype) * math.inf
#         R[:, 0, 0] = 0

#         # Run the CUDA kernel.
#         # Set CUDA's grid size to be equal to the batch size (every CUDA block processes one sample pair)
#         # Set the CUDA block size to be equal to the length of the longer sequence (equal to the size of the largest diagonal)
#         compute_softdtw_cuda[B, threads_per_block](cuda.as_cuda_array(D.detach()),
#                                                    gamma.item(), bandwidth.item(), N, M, n_passes,
#                                                    cuda.as_cuda_array(R))
#         ctx.save_for_backward(D, R, gamma, bandwidth)
#         return R[:, -2, -2]

#     @staticmethod
#     def backward(ctx, grad_output):
#         dev = grad_output.device
#         dtype = grad_output.dtype
#         D, R, gamma, bandwidth = ctx.saved_tensors

#         B = D.shape[0]
#         N = D.shape[1]
#         M = D.shape[2]
#         threads_per_block = max(N, M)
#         n_passes = 2 * threads_per_block - 1

#         D_ = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
#         D_[:, 1:N + 1, 1:M + 1] = D

#         R[:, :, -1] = -math.inf
#         R[:, -1, :] = -math.inf
#         R[:, -1, -1] = R[:, -2, -2]

#         E = torch.zeros((B, N + 2, M + 2), dtype=dtype, device=dev)
#         E[:, -1, -1] = 1

#         # Grid and block sizes are set same as done above for the forward() call
#         compute_softdtw_backward_cuda[B, threads_per_block](cuda.as_cuda_array(D_),
#                                                             cuda.as_cuda_array(R),
#                                                             1.0 / gamma.item(), bandwidth.item(), N, M, n_passes,
#                                                             cuda.as_cuda_array(E))
#         E = E[:, 1:N + 1, 1:M + 1]
#         return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None


# # ----------------------------------------------------------------------------------------------------------------------
# #
# # The following is the CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw
# # Credit goes to Kanru Hua.
# # I've added support for batching and pruning.
# #
# # ----------------------------------------------------------------------------------------------------------------------
# @jit(nopython=True)
# def compute_softdtw(D, gamma, bandwidth):
#     B = D.shape[0]
#     N = D.shape[1]
#     M = D.shape[2]
#     R = np.ones((B, N + 2, M + 2)) * np.inf
#     R[:, 0, 0] = 0
#     for b in range(B):
#         for j in range(1, M + 1):
#             for i in range(1, N + 1):

#                 # Check the pruning condition
#                 if 0 < bandwidth < np.abs(i - j):
#                     continue

#                 r0 = -R[b, i - 1, j - 1] / gamma
#                 r1 = -R[b, i - 1, j] / gamma
#                 r2 = -R[b, i, j - 1] / gamma
#                 rmax = max(max(r0, r1), r2)
#                 rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
#                 softmin = - gamma * (np.log(rsum) + rmax)
#                 R[b, i, j] = D[b, i - 1, j - 1] + softmin
#     return R

# # ----------------------------------------------------------------------------------------------------------------------
# @jit(nopython=True)
# def compute_softdtw_backward(D_, R, gamma, bandwidth):
#     B = D_.shape[0]
#     N = D_.shape[1]
#     M = D_.shape[2]
#     D = np.zeros((B, N + 2, M + 2))
#     E = np.zeros((B, N + 2, M + 2))
#     D[:, 1:N + 1, 1:M + 1] = D_
#     E[:, -1, -1] = 1
#     R[:, :, -1] = -np.inf
#     R[:, -1, :] = -np.inf
#     R[:, -1, -1] = R[:, -2, -2]
#     for k in range(B):
#         for j in range(M, 0, -1):
#             for i in range(N, 0, -1):

#                 if np.isinf(R[k, i, j]):
#                     R[k, i, j] = -np.inf

#                 # Check the pruning condition
#                 if 0 < bandwidth < np.abs(i - j):
#                     continue

#                 a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma
#                 b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma
#                 c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma
#                 a = np.exp(a0)
#                 b = np.exp(b0)
#                 c = np.exp(c0)
#                 E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
#     return E[:, 1:N + 1, 1:M + 1]

# # ----------------------------------------------------------------------------------------------------------------------
# class _SoftDTW(Function):
#     """
#     CPU implementation based on https://github.com/Sleepwalking/pytorch-softdtw
#     """

#     @staticmethod
#     def forward(ctx, D, gamma, bandwidth):
#         dev = D.device
#         dtype = D.dtype
#         gamma = torch.Tensor([gamma]).to(dev).type(dtype)  # dtype fixed
#         bandwidth = torch.Tensor([bandwidth]).to(dev).type(dtype)
#         D_ = D.detach().cpu().numpy()
#         g_ = gamma.item()
#         b_ = bandwidth.item()
#         R = torch.Tensor(compute_softdtw(D_, g_, b_)).to(dev).type(dtype)
#         ctx.save_for_backward(D, R, gamma, bandwidth)
#         return R[:, -2, -2]

#     @staticmethod
#     def backward(ctx, grad_output):
#         dev = grad_output.device
#         dtype = grad_output.dtype
#         D, R, gamma, bandwidth = ctx.saved_tensors
#         D_ = D.detach().cpu().numpy()
#         R_ = R.detach().cpu().numpy()
#         g_ = gamma.item()
#         b_ = bandwidth.item()
#         E = torch.Tensor(compute_softdtw_backward(D_, R_, g_, b_)).to(dev).type(dtype)
#         return grad_output.view(-1, 1, 1).expand_as(E) * E, None, None

# # ----------------------------------------------------------------------------------------------------------------------
# class SoftDTW(torch.nn.Module):
#     """
#     The soft DTW implementation that optionally supports CUDA
#     """

#     def __init__(self, use_cuda, gamma=1.0, normalize=False, bandwidth=None, dist_func=None):
#         """
#         Initializes a new instance using the supplied parameters
#         :param use_cuda: Flag indicating whether the CUDA implementation should be used
#         :param gamma: sDTW's gamma parameter
#         :param normalize: Flag indicating whether to perform normalization
#                           (as discussed in https://github.com/mblondel/soft-dtw/issues/10#issuecomment-383564790)
#         :param bandwidth: Sakoe-Chiba bandwidth for pruning. Passing 'None' will disable pruning.
#         :param dist_func: Optional point-wise distance function to use. If 'None', then a default Euclidean distance function will be used.
#         """
#         super(SoftDTW, self).__init__()
#         self.normalize = normalize
#         self.gamma = gamma
#         self.bandwidth = 0 if bandwidth is None else float(bandwidth)
#         self.use_cuda = use_cuda

#         # Set the distance function
#         if dist_func is not None:
#             self.dist_func = dist_func
#         else:
#             self.dist_func = SoftDTW._euclidean_dist_func

#     def _get_func_dtw(self, x, y):
#         """
#         Checks the inputs and selects the proper implementation to use.
#         """
#         bx, lx, dx = x.shape
#         by, ly, dy = y.shape
#         # Make sure the dimensions match
#         assert bx == by  # Equal batch sizes
#         assert dx == dy  # Equal feature dimensions

#         use_cuda = self.use_cuda

#         if use_cuda and (lx > 1024 or ly > 1024):  # We should be able to spawn enough threads in CUDA
#                 print("SoftDTW: Cannot use CUDA because the sequence length > 1024 (the maximum block size supported by CUDA)")
#                 use_cuda = False

#         # Finally, return the correct function
#         return _SoftDTWCUDA.apply if use_cuda else _SoftDTW.apply

#     @staticmethod
#     def _euclidean_dist_func(x, y):
#         """
#         Calculates the Euclidean distance between each element in x and y per timestep
#         """
#         n = x.size(1)
#         m = y.size(1)
#         d = x.size(2)
#         x = x.unsqueeze(2).expand(-1, n, m, d)
#         y = y.unsqueeze(1).expand(-1, n, m, d)
#         return torch.pow(x - y, 2).sum(3)

#     def forward(self, X, Y):
#         """
#         Compute the soft-DTW value between X and Y
#         :param X: One batch of examples, batch_size x seq_len x dims
#         :param Y: The other batch of examples, batch_size x seq_len x dims
#         :return: The computed results
#         """

#         # Check the inputs and get the correct implementation
#         func_dtw = self._get_func_dtw(X, Y)

#         if self.normalize:
#             # Stack everything up and run
#             x = torch.cat([X, X, Y])
#             y = torch.cat([Y, X, Y])
#             D = self.dist_func(x, y)
#             out = func_dtw(D, self.gamma, self.bandwidth)
#             out_xy, out_xx, out_yy = torch.split(out, X.shape[0])
#             return out_xy - 1 / 2 * (out_xx + out_yy)
#         else:
#             D_xy = self.dist_func(X, Y)
#             return func_dtw(D_xy, self.gamma, self.bandwidth)

# # ----------------------------------------------------------------------------------------------------------------------
# def timed_run(a, b, sdtw):
#     """
#     Runs a and b through sdtw, and times the forward and backward passes.
#     Assumes that a requires gradients.
#     :return: timing, forward result, backward result
#     """
#     from timeit import default_timer as timer

#     # Forward pass
#     start = timer()
#     forward = sdtw(a, b)
#     end = timer()
#     t = end - start

#     grad_outputs = torch.ones_like(forward)

#     # Backward
#     start = timer()
#     grads = torch.autograd.grad(forward, a, grad_outputs=grad_outputs)[0]
#     end = timer()

#     # Total time
#     t += end - start

#     return t, forward, grads

# # ----------------------------------------------------------------------------------------------------------------------
# def profile(batch_size, seq_len_a, seq_len_b, dims, tol_backward):
#     sdtw = SoftDTW(False, gamma=1.0, normalize=False)
#     sdtw_cuda = SoftDTW(True, gamma=1.0, normalize=False)
#     n_iters = 6

#     print("Profiling forward() + backward() times for batch_size={}, seq_len_a={}, seq_len_b={}, dims={}...".format(batch_size, seq_len_a, seq_len_b, dims))

#     times_cpu = []
#     times_gpu = []

#     for i in range(n_iters):
#         a_cpu = torch.rand((batch_size, seq_len_a, dims), requires_grad=True)
#         b_cpu = torch.rand((batch_size, seq_len_b, dims))
#         a_gpu = a_cpu.cuda()
#         b_gpu = b_cpu.cuda()

#         # GPU
#         t_gpu, forward_gpu, backward_gpu = timed_run(a_gpu, b_gpu, sdtw_cuda)

#         # CPU
#         t_cpu, forward_cpu, backward_cpu = timed_run(a_cpu, b_cpu, sdtw)

#         # Verify the results
#         assert torch.allclose(forward_cpu, forward_gpu.cpu())
#         assert torch.allclose(backward_cpu, backward_gpu.cpu(), atol=tol_backward)

#         if i > 0:  # Ignore the first time we run, in case this is a cold start (because timings are off at a cold start of the script)
#             times_cpu += [t_cpu]
#             times_gpu += [t_gpu]

#     # Average and log
#     avg_cpu = np.mean(times_cpu)
#     avg_gpu = np.mean(times_gpu)
#     print("\tCPU:     ", avg_cpu)
#     print("\tGPU:     ", avg_gpu)
#     print("\tSpeedup: ", avg_cpu / avg_gpu)
#     print()

# # ----------------------------------------------------------------------------------------------------------------------
# if __name__ == "__main__":
#     from timeit import default_timer as timer

#     torch.manual_seed(1234)

#     profile(128, 17, 15, 2, tol_backward=1e-6)
#     profile(512, 64, 64, 2, tol_backward=1e-4)
#     profile(512, 256, 256, 2, tol_backward=1e-3)


## Kmeans func

In [None]:
# from functools import partial

# import os
# import numpy as np
# import torch
# from tqdm import tqdm

# # from soft_dtw_cuda import SoftDTW


# def initialize(X, num_clusters):
#     """
#     initialize cluster centers
#     :param X: (torch.tensor) matrix
#     :param num_clusters: (int) number of clusters
#     :return: (np.array) initial state
#     """
#     num_samples = len(X)
#     indices = np.random.choice(num_samples, num_clusters, replace=False)
#     initial_state = X[indices]
#     return initial_state


# def kmeans(
#         X,
#         num_clusters,
#         distance='euclidean',
#         cluster_centers=[],
#         tol=1e-4,
#         tqdm_flag=True,
#         iter_limit=0,
#         device=torch.device('cpu'),
#         gamma_for_soft_dtw=0.001
# ):
#     """
#     perform kmeans
#     :param X: (torch.tensor) matrix
#     :param num_clusters: (int) number of clusters
#     :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
#     :param tol: (float) threshold [default: 0.0001]
#     :param device: (torch.device) device [default: cpu]
#     :param tqdm_flag: Allows to turn logs on and off
#     :param iter_limit: hard limit for max number of iterations
#     :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
#     :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
#     """
#     if tqdm_flag:
#         print(f'running k-means on {device}..')

#     if distance == 'euclidean':
#         pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
#     elif distance == 'cosine':
#         pairwise_distance_function = partial(pairwise_cosine, device=device)
#     elif distance == 'soft_dtw':
#         sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
#         pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
#     else:
#         raise NotImplementedError

#     # convert to float
#     X = X.float()

#     # transfer to device
#     X = X.to(device)

#     # initialize
#     if type(cluster_centers) == list:  # ToDo: make this less annoyingly weird
#         initial_state = initialize(X, num_clusters)
#     else:
#         if tqdm_flag:
#             print('resuming')
#         # find data point closest to the initial cluster center
#         initial_state = cluster_centers
#         dis = pairwise_distance_function(X, initial_state)
#         choice_points = torch.argmin(dis, dim=0)
#         initial_state = X[choice_points]
#         initial_state = initial_state.to(device)

#     iteration = 0
#     if tqdm_flag:
#         tqdm_meter = tqdm(desc='[running kmeans]')
#     while True:

#         dis = pairwise_distance_function(X, initial_state)

#         choice_cluster = torch.argmin(dis, dim=1)

#         initial_state_pre = initial_state.clone()

#         for index in range(num_clusters):
#             selected = torch.nonzero(choice_cluster == index).squeeze().to(device)

#             selected = torch.index_select(X, 0, selected)

#             # https://github.com/subhadarship/kmeans_pytorch/issues/16
#             if selected.shape[0] == 0:
#                 selected = X[torch.randint(len(X), (1,))]

#             initial_state[index] = selected.mean(dim=0)

#         center_shift = torch.sum(
#             torch.sqrt(
#                 torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
#             ))

#         # increment iteration
#         iteration = iteration + 1

#         # update tqdm meter
#         if tqdm_flag:
#             tqdm_meter.set_postfix(
#                 iteration=f'{iteration}',
#                 center_shift=f'{center_shift ** 2:0.6f}',
#                 tol=f'{tol:0.6f}'
#             )
#             tqdm_meter.update()
#         if center_shift ** 2 < tol:
#             break
#         if iter_limit != 0 and iteration >= iter_limit:
#             break

#     return choice_cluster.cpu(), initial_state.cpu()


# def kmeans_predict(
#         X,
#         cluster_centers,
#         distance='euclidean',
#         device=torch.device('cpu'),
#         gamma_for_soft_dtw=0.001,
#         tqdm_flag=True
# ):
#     """
#     predict using cluster centers
#     :param X: (torch.tensor) matrix
#     :param cluster_centers: (torch.tensor) cluster centers
#     :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
#     :param device: (torch.device) device [default: 'cpu']
#     :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
#     :return: (torch.tensor) cluster ids
#     """
#     if tqdm_flag:
#         print(f'predicting on {device}..')

#     if distance == 'euclidean':
#         pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
#     elif distance == 'cosine':
#         pairwise_distance_function = partial(pairwise_cosine, device=device)
#     elif distance == 'soft_dtw':
#         sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
#         pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
#     else:
#         raise NotImplementedError

#     # convert to float
#     X = X.float()

#     # transfer to device
#     X = X.to(device)

#     dis = pairwise_distance_function(X, cluster_centers)
#     choice_cluster = torch.argmin(dis, dim=1)

#     #return choice_cluster.cpu()
#     return dis


# def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
#     if tqdm_flag:
#         print(f'device is :{device}')
    
#     # transfer to device
#     data1, data2 = data1.to(device), data2.to(device)

#     # N*1*M
#     A = data1.unsqueeze(dim=1)

#     # 1*N*M
#     B = data2.unsqueeze(dim=0)

#     dis = (A - B) ** 2.0
#     # return N*N matrix for pairwise distance
#     dis = dis.sum(dim=-1).squeeze()
#     return dis


# def pairwise_cosine(data1, data2, device=torch.device('cpu')):
#     # transfer to device
#     data1, data2 = data1.to(device), data2.to(device)

#     # N*1*M
#     A = data1.unsqueeze(dim=1)

#     # 1*N*M
#     B = data2.unsqueeze(dim=0)

#     # normalize the points  | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
#     A_normalized = A / A.norm(dim=-1, keepdim=True)
#     B_normalized = B / B.norm(dim=-1, keepdim=True)

#     cosine = A_normalized * B_normalized

#     # return N*N matrix for pairwise distance
#     cosine_dis = 1 - cosine.sum(dim=-1).squeeze()
#     return cosine_dis


# def pairwise_soft_dtw(data1, data2, sdtw=None, device=torch.device('cpu')):
#     if sdtw is None:
#         raise ValueError('sdtw is None - initialize it with SoftDTW')

#     # transfer to device
#     data1, data2 = data1.to(device), data2.to(device)

#     # (batch_size, seq_len, feature_dim=1)
#     A = data1.unsqueeze(dim=2)

#     # (cluster_size, seq_len, feature_dim=1)
#     B = data2.unsqueeze(dim=2)

#     distances = []
#     for b in B:
#         # (1, seq_len, 1)
#         b = b.unsqueeze(dim=0)
#         A, b = torch.broadcast_tensors(A, b)
#         # (batch_size, 1)
#         sdtw_distance = sdtw(b, A).view(-1, 1)
#         distances.append(sdtw_distance)

#     # (batch_size, cluster_size)
#     dis = torch.cat(distances, dim=1)
#     return dis


# if __name__ == "__main__":
    
#     print("torch = {}".format(torch.__version__))
#     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#     device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
#     feature_path = '/ceph-jd/pub/jupyter/zhaozy/notebooks/LWJ/UnsupModel/Reuslts/pretrained_weights/2022-01-01_02-09-51/feature_map.npy' 
#     save_path = '/ceph-jd/pub/jupyter/zhaozy/notebooks/LWJ/UnsupModel/Reuslts/'
#     matrix = 'euclidean'
    
#     feature_map = np.load(feature_path)
#     map1=torch.from_numpy(feature_map)
#     map2 = torch.flatten(map1, start_dim=1, end_dim=-1)
#     cluster_ids_x, cluster_centers = kmeans(
#         X=map2, num_clusters=3, distance=matrix, device=device)
    
    
   
#     print('save at:')
#     print('cluster_ids_x :',os.path.join(save_path,matrix + '_cluster_ids_x.pt'))
#     print('cluster_centers: ', os.path.join(save_path,matrix + '_cluster_centers.pt'))
    
#     torch.save(cluster_ids_x,os.path.join(save_path,matrix + '_cluster_ids_x.pt'))
#     torch.save(cluster_centers,os.path.join(save_path,matrix + '_cluster_centers.pt'))
    
#     cluster_dis = kmeans_predict(map2, cluster_centers, matrix, device=device)
    
#     torch.save(cluster_dis,os.path.join(save_path,matrix + '_cluster_centers_dis.pt'))
    
    
#     cluster_map = []
#     for i in range(3):
#         dis, idx_sort = torch.sort(cluster_dis[:,i], dim=0, descending=False)
#         cluster_map.append({'dis':dis,'idx_sort':idx_sort})
            
#     np.save(os.path.join(save_path,matrix + '_cluster.npy'),cluster_map)


## Config_cluster

In [None]:
import os
import shutil
from datetime import datetime
import pytz

class models_genesis_config:
    """Static configuration for the ``Unet2D`` pre-training / clustering run.

    Every setting is a plain class attribute, so it can be read either from
    the class or from an instance.

    NOTE: the "logs" section runs at class-definition time and has side
    effects: it stamps the current Asia/Singapore time into ``model_path``
    and creates the run, ``Logs`` and ``snapshot`` directories on disk.

    Do not add helper names (loop variables, temporaries) to the class body:
    ``display()`` prints every non-dunder, non-callable attribute.
    """

    model = "Unet2D"
    suffix = "genesis_chest_ct"
    exp_name = model + "-" + suffix

    # data
    data = "/mnt/dataset/shared/zongwei/LUNA16/Self_Learning_Cubes" # not use
    scale = 32
    input_rows = 256
    input_cols = 256
    input_deps = 1
    nb_class = 1

    # image deformation
    nonlinear_rate = 0.9
    paint_rate = 0.9
    outpaint_rate = 0.8
    inpaint_rate = 1.0 - outpaint_rate  # complement of outpaint_rate
    local_rate = 0.5
    flip_rate = 0.4

    # logs
    model_dir = "/content/drive/MyDrive/Spring_research_2023/SSLModel/Reuslts/pretrained_weights"
    # One sub-directory per run, named after the wall-clock time
    # (YYYY-MM-DD_HH-MM-SS, Asia/Singapore) at which this class is defined.
    timenow = datetime.strftime(datetime.now(pytz.timezone('Asia/Singapore')), '%Y-%m-%d_%H-%M-%S')
    model_path = os.path.join(model_dir,timenow)
    print('Model path: ',model_path)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(model_path, exist_ok=True)

    logs_path = os.path.join(model_path, "Logs")
    print('log path: ',logs_path)
    os.makedirs(logs_path, exist_ok=True)

    shotdir = os.path.join(model_path, 'snapshot')
    print('snapshot path: ',shotdir)
    os.makedirs(shotdir, exist_ok=True)

    # model pre-training
    verbose = 1
    weights = os.path.join(model_path,'ISIC_Unsup.pt')
    batch_size = 1
    optimizer = "sgd"
    workers = 10
    max_queue_size = workers * 4
    save_samples = "png"
    nb_epoch = 100
    patience = 50
    lr = 0.01

    def display(self):
        """Print every non-dunder, non-callable attribute as ``name value``."""
        print("\nConfigurations:")
        for a in dir(self):
            if a.startswith("__"):
                continue
            value = getattr(self, a)
            if callable(value):
                continue
            print("{:30} {}".format(a, value))
        print("\n")


## Cluster

In [None]:
#!/usr/bin/env python
# coding: utf-8


import warnings
warnings.filterwarnings('ignore')
import numpy as np
from torch import nn
import torch
from torchsummary import summary
import sys
# from utils import *
# import unet3d
# from unet_model2 import UNet, UNet_hidden
# from config_cluster import models_genesis_config
from tqdm import tqdm
# from data_load import Dataset_Loader
from datetime import datetime
import logging

import numpy as np
import matplotlib.pyplot as plt
# from kmeans_func import kmeans, kmeans_predict
import os
from datetime import datetime
import pytz

# =================================================
#             load data and model
# =================================================

print("torch = {}".format(torch.__version__))
# NOTE(review): this exposes only the GPU with physical index 1. On a host
# with a single GPU (index 0) no CUDA device will be visible and everything
# below silently falls back to the CPU -- confirm this is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Total CUDA devices: ", torch.cuda.device_count())
img_size = [256,256]
input_rows, input_cols = 256, 256
conf = models_genesis_config()

train_path = '/content/drive/MyDrive/Spring_research_2023/data/GrayData/imgs_train.npy'
train_set = Dataset_Loader(train_path,img_size)

# Only the first `train_num` samples are used for feature extraction.
train_num =  1600
x_train = train_set[0:train_num]
print("x_train: {} | {:.2f} ~ {:.2f}".format(x_train.shape, np.min(x_train), np.max(x_train)))
training_generator = generate_pair(x_train,conf.batch_size, conf)

model = UNet_hidden(n_channels=1, n_classes=conf.nb_class).cpu()
model.to(device)
summary(model, (1,input_rows,input_cols), batch_size=-1)
criterion = nn.MSELoss()

# =================================================
#             extract hidden features
# =================================================
conf.weights = '/content/drive/MyDrive/Spring_research_2023/SSLModel/Reuslts/pretrained_weights/2023-04-02_02-19-33/ISIC_Unsup.pt'

if conf.weights is not None:
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    checkpoint=torch.load(conf.weights, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    print("Loading weights from ",conf.weights)
sys.stdout.flush()

# NOTE(review): the model is left in training mode (no model.eval()), so any
# BatchNorm/Dropout layers behave as in training -- confirm this is intended
# before changing it, since it affects the extracted features.
feature_map = []
num_batches = int(x_train.shape[0]//conf.batch_size)
# Inference only: no_grad avoids building autograd graphs for every batch.
with torch.no_grad():
    for iteration in tqdm(range(num_batches)):
        image, _ = next(training_generator)
        image = torch.from_numpy(image).float().to(device)
        # The model returns (output, hidden feature); only the latter is kept.
        _, feature=model(image)
        descriptors = feature.cpu().detach().numpy()
        # One entry per sample in the batch.
        feature_map.extend(descriptors)
print('\nsize of feature_map:',np.shape(feature_map))
feature_map_path = os.path.join(conf.model_path,'2023-04-02_02-19-33_feature_map.npy')
np.save(feature_map_path,feature_map)
print('path of feature map:',feature_map_path)
# Alternative: reuse a previously saved feature map instead of re-extracting.
# feature_path = '../SSLModel/Reuslts/pretrained_weights/2022-01-13_04-14-30/feature_map.npy'
# feature_map = np.load(feature_path)
newmap=torch.from_numpy(np.array(feature_map))

# =================================================
#           dimensionality reduction
# =================================================
class PCA(object):
    """Minimal principal component analysis on 2-D torch tensors.

    ``fit`` stores the per-feature mean in ``self.mean`` and the projection
    matrix (features x n_components, columns ordered by decreasing
    eigenvalue of the covariance matrix) in ``self.proj_mat``.
    ``transform`` centers the data with that mean and projects it.
    """

    def __init__(self, n_components=2):
        """Keep the number of leading components to retain."""
        self.n_components = n_components

    def fit(self, X):
        """Estimate the mean and principal axes from ``X`` (samples x features)."""
        n = X.shape[0]
        self.mean = torch.mean(X, dim=0)
        X = X - self.mean
        # Biased (1/n) covariance estimate, features x features.
        covariance_matrix = 1 / n * torch.matmul(X.T, X)
        # The covariance matrix is symmetric, so use the symmetric solver.
        # torch.eig is deprecated and no longer available in current PyTorch;
        # eigh returns real eigenvalues and orthonormal eigenvectors.
        eigenvalues, eigenvectors = torch.linalg.eigh(covariance_matrix)
        # eigh returns ascending eigenvalues; PCA wants the largest first.
        idx = torch.argsort(eigenvalues, descending=True)
        eigenvectors = eigenvectors[:, idx]
        self.proj_mat = eigenvectors[:, 0:self.n_components]

    def transform(self, X):
        """Center ``X`` with the fitted mean and project onto the kept axes."""
        X = X - self.mean
        return X.matmul(self.proj_mat)

print('========== processing dimensionality reduction ===========')
# Selects how the hidden feature maps in `newmap` are turned into one vector
# per sample: 'flatten_PCA', 'pooling_512' or 'pooling_2048'.
redim_type = 'pooling_512'
dim = 512
if redim_type == 'flatten_PCA':  # in paper, no use PCA
    # flatten everything but the sample axis
    flatten_map = torch.flatten(newmap, start_dim=1, end_dim=-1)
    # pca: project onto all components, then keep the leading `dim`
    pca = PCA(n_components=np.shape(flatten_map)[1])
    pca.fit(flatten_map)
    X_all = pca.transform(flatten_map)
    reduced = X_all[:,0:dim]

elif redim_type == 'pooling_512':
    # adaptive average pool each feature map down to 1x1, then flatten
    # (presumably 512 channels, given the option name -- TODO confirm)
    aap512 = nn.AdaptiveAvgPool2d((1))
    map_aap512 = aap512(newmap)
    reduced =  torch.flatten(map_aap512, start_dim=1, end_dim=-1)

elif redim_type == 'pooling_2048':
    # adaptive average pool each feature map down to 2x2, then flatten
    aap2048 = nn.AdaptiveAvgPool2d((2,2))
    map_aap2048 = aap2048(newmap)
    reduced = torch.flatten(map_aap2048, start_dim=1, end_dim=-1)

else:
    # Fail here with a clear message instead of a NameError on `reduced` below.
    raise ValueError(
        "unknown redim_type {!r}; expected 'flatten_PCA', 'pooling_512' "
        "or 'pooling_2048'".format(redim_type)
    )

print(redim_type)
print('\nfeature shape:', np.shape(reduced))

# =================================================
#                     clustering
# =================================================
dir_path = '/content/drive/MyDrive/Spring_research_2023/Cluster_Results/Hidden_features' # save path of features
matrix = 'euclidean'
num_clusters = 10

# One clustering run per entry. The previous `for dim in range(512)` repeated
# the identical clustering 512 times and wrote dim0 ... dim511 directories.
# dim_list = [512,256,128,64]
dim_list = [512]
for dim in dim_list:
    # reduced = X_all[:,0:dim]
    # `dim` only labels the output directory; the features are used as-is.
    x = reduced

    name = matrix+'_'+redim_type+'_dim'+str(dim)
    save_path = os.path.join(dir_path,name)
    os.makedirs(save_path, exist_ok=True)
    print('\nFeatures save under: ',save_path)

    # Timestamp prefix keeps repeated runs from overwriting each other.
    timenow = datetime.strftime(datetime.now(pytz.timezone('Asia/Singapore')),'%Y-%m-%d_%H-%M-%S')

    # NOTE(review): the k-means cell visible in this file is entirely
    # commented out -- make sure kmeans / kmeans_predict are defined elsewhere.
    cluster_ids_x, cluster_centers = kmeans(
        X=x, num_clusters=num_clusters, distance=matrix, device=device
    )
    # Assumes kmeans_predict returns the sample-to-center distance matrix
    # (as the commented-out implementation in this file does), not cluster ids.
    cluster_dis = kmeans_predict(x, cluster_centers, matrix, device=device)

    ids_path = os.path.join(save_path,timenow + '_cluster_ids_x.pt')
    centers_path = os.path.join(save_path,timenow + '_cluster_centers.pt')
    dis_path = os.path.join(save_path,timenow + '_cluster_centers_dis.pt')
    rank_path = os.path.join(save_path,timenow + '_cluster.npy')

    print('\nsave at:')
    print('cluster_ids_x :',ids_path)
    print('cluster_centers: ', centers_path)
    print('cluster_distances: ', dis_path)

    torch.save(cluster_ids_x,ids_path)
    torch.save(cluster_centers,centers_path)
    torch.save(cluster_dis,dis_path)

    # For every cluster, rank all samples by increasing distance to its center.
    cluster_map = []
    for i in range(num_clusters):
        dis, idx_sort = torch.sort(cluster_dis[:,i], dim=0, descending=False)
        cluster_map.append({'dis':dis,'idx_sort':idx_sort})

    print('cluster_distance rank: ', rank_path)
    # Saved as an object array of dicts; np.load needs allow_pickle=True.
    np.save(rank_path,cluster_map)
