In [None]:
import torch
import torchvision
from torch.utils.data import Dataset
from torch.autograd import Variable
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import sys 
import os 
import scipy.io as sio
import random
from sklearn.decomposition import PCA
import spectral

In [None]:
import torch.nn as nn

In [None]:
# Global variables   
# Global patch geometry: 11 x 11 spatial window, 30 spectral bands (the PCA
# output size used below), single input channel.
im_width = 11
im_height = 11
im_depth = 30
im_channel = 1

In [None]:
class MAML(nn.Module):
  """3-D CNN backbone used for MAML on hyperspectral patches.

  Input is (batch, 1, depth, height, width). The first Linear layer takes 14400
  features, which matches (1, 30, 11, 11) patches: three valid convolutions give
  32 channels x 18 x 5 x 5 = 14400.

  forward() returns raw, unnormalised logits. Every consumer in this notebook
  feeds the output to F.cross_entropy (which applies log-softmax itself) or to
  argmax. Applying softmax inside forward() therefore normalised twice and
  squashed the loss and gradients. Argmax predictions are unaffected by the
  change. ``self.softmax`` is kept for callers that need probabilities
  explicitly.

  Args:
    n_classes: size of the final classification layer (default 15).
  """
  def __init__(self, n_classes=15):
    super(MAML, self).__init__()
    # Not applied in forward(); available for explicit probability outputs.
    self.softmax = nn.Softmax(dim=1)
    # Layer order is unchanged, so saved state_dict keys (model.0 ... model.16)
    # still load.
    layers = [
      nn.Conv3d(1, 8, (7, 3, 3)), nn.ReLU(), nn.BatchNorm3d(8), nn.Dropout(0.5),
      nn.Conv3d(8, 16, (5, 3, 3)), nn.ReLU(), nn.BatchNorm3d(16), nn.Dropout(0.5),
      nn.Conv3d(16, 32, (3, 3, 3)), nn.ReLU(), nn.BatchNorm3d(32),
      nn.Flatten(),
      nn.Dropout(0.5),
      nn.Linear(14400, 256),
      nn.Dropout(0.5),
      nn.Linear(256, 128),
      nn.Linear(128, n_classes),
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    """Return class logits of shape (batch, n_classes)."""
    return self.model(x)


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
device

In [None]:
# Instantiate the backbone and move its parameters to the selected device.
maml_model = MAML()
maml_model = maml_model.to(device)

In [None]:
def loadData(name, data_dir=''):
    """Load a hyperspectral cube and its ground-truth label map from .mat files.

    Args:
        name: dataset key. NOTE: 'IP' loads the *Houston* files
            (Houston.mat / Houston_gt.mat), not Indian Pines. 'houston' is
            accepted as a clearer alias for the same files.
        data_dir: directory containing the .mat files. The default '' means
            the current working directory, as before.

    Returns:
        (data, labels): the arrays stored under the dataset's variable names.

    Raises:
        ValueError: if ``name`` is not a known key. Previously an unknown name
            fell through every ``if`` and crashed with UnboundLocalError.
    """
    # key -> (data file, data variable, ground-truth file, ground-truth variable)
    houston = ('Houston.mat', 'Houston', 'Houston_gt.mat', 'Houston_gt')
    sources = {
        'IP': houston,
        'houston': houston,
        'salinas': ('Salinas.mat', 'salinas', 'Salinas_gt.mat', 'salinas_gt'),
        'pavia': ('PaviaU.mat', 'paviaU', 'PaviaU_gt.mat', 'paviaU_gt'),
        'ksc': ('KSC.mat', 'KSC', 'KSC_gt.mat', 'KSC_gt'),
        'botswana': ('Botswana.mat', 'Botswana', 'Botswana_gt.mat', 'Botswana_gt'),
    }
    if name not in sources:
        raise ValueError('unknown dataset %r; expected one of %s' % (name, sorted(sources)))
    data_file, data_key, gt_file, gt_key = sources[name]
    data = sio.loadmat(os.path.join(data_dir, data_file))[data_key]
    labels = sio.loadmat(os.path.join(data_dir, gt_file))[gt_key]
    return data, labels
# PCA shrinks the spectral axis (e.g. 200 bands -> 30) before patch cubes are
# built; without this reduction, building the cubes ran out of memory.
def applyPCA(X, numComponents):
    """Whitened PCA over the band axis of an (H, W, B) cube.

    Every pixel's spectrum is one PCA sample.

    Returns:
        (reduced, pca): the cube reshaped to (H, W, numComponents) and the
        fitted sklearn PCA object.
    """
    height, width, bands = X.shape[0], X.shape[1], X.shape[2]
    pixels = np.reshape(X, (-1, bands))
    pca = PCA(n_components=numComponents, whiten=True)
    reduced = pca.fit_transform(pixels).reshape(height, width, numComponents)
    return reduced, pca

def padWithZeros(X, margin):
    """Return a float64 copy of X zero-padded by ``margin`` pixels on both
    spatial axes (axes 0 and 1). The band axis (axis 2) is not padded."""
    rows, cols, bands = X.shape[0], X.shape[1], X.shape[2]
    padded = np.zeros((rows + 2 * margin, cols + 2 * margin, bands))
    padded[margin:margin + rows, margin:margin + cols, :] = X
    return padded

def createImageCubes(X, y, windowSize, removeZeroLabels = True):
    """Cut a (windowSize x windowSize x bands) patch around pixels of X.

    Args:
        X: (H, W, B) cube.
        y: (H, W) label map aligned with X.
        windowSize: odd patch side length. X is zero-padded by
            (windowSize - 1) // 2 so border pixels still get full patches.
        removeZeroLabels: when True (default), only pixels with a non-zero
            label are kept. The previous implementation accepted this flag but
            ignored it and always allocated H*W patches, which dominates memory
            on large scenes. Pass False to get one patch per pixel.

    Returns:
        (patchesData, patchesLabels): float64 arrays of shape
        (N, windowSize, windowSize, B, 1) and (N,), in row-major pixel order.

    Raises:
        ValueError: if windowSize is not a positive odd integer, or y's shape
            does not match X's spatial shape.
    """
    if windowSize < 1 or windowSize % 2 == 0:
        raise ValueError('windowSize must be a positive odd integer, got %r' % (windowSize,))
    rows, cols = X.shape[0], X.shape[1]
    y = np.asarray(y)
    if y.shape != (rows, cols):
        raise ValueError('label map shape %s does not match cube spatial shape %s'
                         % (y.shape, (rows, cols)))
    margin = (windowSize - 1) // 2
    zeroPaddedX = padWithZeros(X, margin=margin)

    # Row-major (r, c) coordinates of the pixels to keep.
    keep = (y != 0) if removeZeroLabels else np.ones((rows, cols), dtype=bool)
    coords = np.argwhere(keep)

    # Allocate only for the kept pixels.
    patchesData = np.zeros((len(coords), windowSize, windowSize, X.shape[2]))
    patchesLabels = np.zeros(len(coords))
    for idx, (r, c) in enumerate(coords):
        # Pixel (r, c) sits at (r + margin, c + margin) in the padded cube, so
        # its window starts at (r, c) there.
        patchesData[idx] = zeroPaddedX[r:r + windowSize, c:c + windowSize]
        patchesLabels[idx] = y[r, c]

    patchesData = np.expand_dims(patchesData, axis=-1)
    return patchesData, patchesLabels

In [None]:
# Source (meta-training) domain. NOTE: despite the key, loadData('IP') reads
# Houston.mat / Houston_gt.mat; the `sa_` names do not refer to Salinas here.
dataset1 = 'IP'
sa_x1, sa_y = loadData(name=dataset1)
# 30 PCA components to match im_depth; 11x11 windows to match im_height/im_width.
sa_x2, pca = applyPCA(X=sa_x1, numComponents=30)
sa_X, sa_Y = createImageCubes(X=sa_x2, y=sa_y, windowSize=11)
print(sa_X.shape, sa_Y.shape)

In [None]:
# Target (fine-tuning / test) domain: Salinas. NOTE: the `*IP` suffix on these
# names does not refer to Indian Pines; `pca` is overwritten by this fit.
dataset1IP = 'salinas'
sa_x1IP, sa_yIP = loadData(name=dataset1IP)
# Same 30-component PCA and 11x11 windows as the source domain.
sa_x2IP, pca = applyPCA(X=sa_x1IP, numComponents=30)
sa_XIP, sa_YIP = createImageCubes(X=sa_x2IP, y=sa_yIP, windowSize=11)
print(sa_XIP.shape, sa_YIP.shape)

In [None]:
def patches_class(X,Y,n) :
  """Group 5-D patch array X by label.

  Returns a list of n arrays; entry i holds the rows of X whose label in Y
  equals i + 1. Label 0 (unlabelled) is skipped.
  """
  return [X[Y == label, :, :, :, :] for label in range(1, n + 1)]

In [None]:
# Per-class patches of the source domain (sa_X comes from loadData('IP'), i.e. Houston).
patches_class_salinas = patches_class(X=sa_X, Y=sa_Y, n=15)

In [None]:
# Per-class patches of the target (test) domain: sa_XIP comes from loadData('salinas').
patches_class_IP = patches_class(X=sa_XIP, Y=sa_YIP, n=16)

In [None]:
# All 15 source classes are used for meta-training and all 16 target classes for
# testing. Indices are 0-based positions in the per-class lists; labels are the
# matching 1-based class ids.
train_class_indices = list(range(15))
test_class_indices = list(range(16))
train_patches_class = [patches_class_salinas[i] for i in train_class_indices]
test_patches_class = [patches_class_IP[i] for i in test_class_indices]
train_class_labels = [i + 1 for i in train_class_indices]
test_class_labels = [i + 1 for i in test_class_indices]

In [None]:
C = 15   # classes per training episode (n_class)
K1 = 10  # support samples per class (n_support)
N = 20   # query samples per class (n_query)
tC = 16  # classes in a test episode
im_height = 11
im_width = 11
im_depth = 30

In [None]:
def new_episode(patches_list,K,C,N,class_labels) :
  """Sample one C-way K-shot training episode with N queries per class.

  Args:
    patches_list: per-class patch arrays. ``patches_list[x-1]`` holds the
      (H, W, D, 1) patches of 1-based class label ``x``.
    K: support samples per class.
    C: number of classes in the episode; must equal len(class_labels).
    N: query samples per class.
    class_labels: 1-based labels of the classes to use (all are used).

  Returns:
    (query, support, query_labels, support_labels, selected_classes):
    - query and support are float32 tensors of shape (C*N, 1, D, W, H) and
      (C*K, 1, D, W, H).
    - Each label sequence is shuffled together with its patches.
    - selected_classes is class_labels itself.

  Support and query indices are drawn jointly without replacement, so no patch
  appears in both sets of one episode. The previous version drew them
  independently, letting query patches duplicate support patches. Patch
  dimensions now come from the data rather than the im_* globals.

  Raises:
    ValueError: if len(class_labels) != C, or a class has fewer than K + N patches.
  """
  selected_classes = class_labels
  if len(selected_classes) != C:
    raise ValueError('expected %d classes, got %d' % (C, len(selected_classes)))
  tsupport_patches, tquery_patches = [], []
  support_labels, query_labels = [], []
  for x in selected_classes:
    class_patches = patches_list[x-1]
    if len(class_patches) < K + N:
      raise ValueError('class %r has %d patches, need K + N = %d' % (x, len(class_patches), K + N))
    # One draw of K + N distinct indices: the first K are support, the rest query.
    idx = np.random.choice(len(class_patches), K + N, replace=False)
    tsupport_patches.extend(class_patches[idx[:K]])
    tquery_patches.extend(class_patches[idx[K:]])
    support_labels.extend([x] * K)
    query_labels.extend([x] * N)
  # Shuffle patches and labels together so class order carries no signal.
  temp1 = list(zip(tquery_patches, query_labels))
  random.shuffle(temp1)
  tquery_patches, query_labels = zip(*temp1)
  temp2 = list(zip(tsupport_patches, support_labels))
  random.shuffle(temp2)
  tsupport_patches, support_labels = zip(*temp2)
  # (B, H, W, D, 1) -> (B, 1, D, W, H), the layout Conv3d expects.
  tquery_patches = torch.from_numpy(np.asarray(tquery_patches, dtype=np.float32)).permute(0,4,3,2,1)
  tsupport_patches = torch.from_numpy(np.asarray(tsupport_patches, dtype=np.float32)).permute(0,4,3,2,1)
  return tquery_patches, tsupport_patches, query_labels, support_labels, selected_classes

In [None]:
# Draw one training episode from the source domain (sanity check of the sampler).
episode = new_episode(patches_class_salinas, K=K1, C=C, N=N, class_labels=train_class_labels)
tquery_patches, tsupport_patches, query_labels, support_labels, selected_classes = episode

In [None]:
# Outer-loop (meta) optimiser over the backbone's parameters.
meta_opt = torch.optim.Adam(params=maml_model.parameters(), lr=1e-4, betas=(0.5, 0.999))


In [None]:
# `maml_model.parameters` without parentheses only displays a bound method.
# Show the architecture and the number of trainable parameters instead.
n_trainable = sum(p.numel() for p in maml_model.parameters() if p.requires_grad)
print(maml_model)
n_trainable

In [None]:
# Meta-training (source-domain) checkpoints.
checkpoint_dir1 = 'houston2/ckpts'
checkpoint_prefix1 = os.path.join(checkpoint_dir1, 'ckpt')

In [None]:
# Seed every RNG used below, not just NumPy:
# - torch: dropout and the new classification head's initialisation;
# - np.random: per-class shuffles, episode sampling, cutout;
# - random: episode shuffling.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
# Shuffle every target class in place so the patches taken for fine-tuning below
# are a random draw. Iterating the list (instead of a hard-coded range(16)) keeps
# this valid for any number of classes.
for class_patches in test_patches_class:
    np.random.shuffle(class_patches)

In [None]:
# Fine-tuning pool: the first N_TUNE (already shuffled) patches of every target
# class. The previous comment said 5, but 10 were always taken; episodes split
# this pool into 5 support + 5 query.
# .copy(): a plain slice is a view, so in-place shuffles of the pool would
# otherwise also reorder test_patches_class.
N_TUNE = 10
tune_set_5 = [class_patches[:N_TUNE, :, :, :, :].copy() for class_patches in test_patches_class]

In [None]:
# Quick check: number of classes in the pool and the shape of one class's patches.
n_tune_classes = len(tune_set_5)
print(n_tune_classes)
print(tune_set_5[0].shape)

In [None]:
C = 16  # fine-tuning / test episodes below are 16-way (all target classes)

In [None]:
def tune_episode(tune_set,tK,tN,test_class_labels) :
  """Sample a fine-tuning episode from the small labelled target pool.

  For each class, a random permutation of its pool ``tune_set[i]`` is split into
  tK support patches and the next tN query patches.

  Fixes vs. the previous version:
    * The query slice was hard-coded to ``[tK:10]`` while tN labels were
      appended. Any tK + tN != 10 silently mismatched patches and labels,
      because zip truncates.
    * Episode size came from the global ``tC`` and patch dimensions from the
      im_* globals. Both are now derived from the arguments and the data.
    * The pool is no longer shuffled in place; np.random.shuffle mutated the
      caller's arrays.

  Returns:
    (query, support, query_labels, support_labels, selected_classes):
    - query and support are float32 tensors of shape (n_classes*tN, 1, D, W, H)
      and (n_classes*tK, 1, D, W, H), for (H, W, D, 1) patches.
    - Query patches are shuffled together with their labels; support stays
      grouped by class.
    - selected_classes is a new list.

  Raises:
    ValueError: if a class pool holds fewer than tK + tN patches.
  """
  selected_classes = test_class_labels
  support_labels, query_labels = [], []
  support_patches, query_patches = [], []
  for x in selected_classes:
    pool = tune_set[test_class_labels.index(x)]
    if len(pool) < tK + tN:
      raise ValueError('class %r pool has %d patches, need tK + tN = %d' % (x, len(pool), tK + tN))
    order = np.random.permutation(len(pool))
    support_patches.extend(pool[order[:tK]])
    query_patches.extend(pool[order[tK:tK + tN]])
    support_labels.extend([x] * tK)
    query_labels.extend([x] * tN)
  temp1 = list(zip(query_patches, query_labels))
  random.shuffle(temp1)
  query_patches, query_labels = zip(*temp1)
  # (B, H, W, D, 1) -> (B, 1, D, W, H), the layout Conv3d expects.
  query_patches = torch.from_numpy(np.asarray(query_patches, dtype=np.float32)).permute(0,4,3,2,1)
  support_patches = torch.from_numpy(np.asarray(support_patches, dtype=np.float32)).permute(0,4,3,2,1)
  return query_patches, support_patches, query_labels, support_labels, list(selected_classes)

In [None]:
# Specific meta-training checkpoint to fine-tune from.
# NOTE(review): run-specific file name; confirm it matches the intended run.
checkpoint_prefixa = 'houston2/ckpts/' + 'ckpt399439479'

In [None]:
# Fine-tuning (target-domain) checkpoints.
checkpoint_dir2 = 'tuninghouston/ckpts'
checkpoint_prefix2 = os.path.join(checkpoint_dir2, 'ckpt')

In [None]:
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
# torch.load unpickles the file: only load checkpoints you produced yourself.
checkpoint_tune = torch.load(checkpoint_prefixa, map_location=device)
maml_model.load_state_dict(checkpoint_tune['model_state_dict'])
meta_opt.load_state_dict(checkpoint_tune['optimizer_state_dict'])

In [None]:
# Drop the final 15-way Linear layer of the source-domain classifier; the
# target-domain head is attached in the next cell.
# NOTE: not idempotent -- re-running this cell removes another layer.
backbone_layers = list(maml_model.model.children())[:-1]
maml_model.model = nn.Sequential(*backbone_layers)

In [None]:
# New head with one output per target class (16), fed by the 128-d features of
# the preceding Linear(256, 128).
maml_model.model.add_module('extra', nn.Linear(128, len(test_class_labels)))
# meta_opt was built before this layer existed, so without this call
# meta_opt.step() would never update the new head during fine-tuning.
# NOTE: a fine-tuning checkpoint saved before this change has one fewer param
# group and will not load into meta_opt; re-run fine-tuning to regenerate it.
meta_opt.add_param_group({'params': maml_model.model.extra.parameters()})

In [None]:
# Move the new head (created on the CPU) to the same device as the backbone.
maml_model.to(device=device)

In [None]:
def cutout(img, length, num_band):
    """Band-wise Cutout augmentation for a batch of hyperspectral patches.

    For every sample in the batch, ``n_bands // num_band`` distinct spectral
    bands are chosen at random. In each chosen band, a square around a random
    pixel is set to 0. The square spans ``2 * (length // 2)`` pixels per side,
    clipped at the borders.

    Args:
        img: tensor or ndarray of shape (batch, channels, bands, height, width),
            e.g. the episode tensors built in this notebook.
        length: side length of the masked square.
        num_band: one band out of every ``num_band`` is masked.

    Returns:
        A masked copy of ``img`` with the same type, shape and dtype. ``img``
        itself is not modified.

    Bugs fixed in the previous version:
    - It indexed axis 0 (the batch axis) with band indices, masking whole
      samples instead of bands.
    - It multiplied a torch tensor in place by a NumPy mask.
    - It ended with ``img[2] = data[0]`` etc. on an aliased array, which
      overwrote samples 2-4 with copies of samples 0-1 while their labels
      stayed unchanged.
    """
    out = img.clone() if torch.is_tensor(img) else np.array(img, copy=True)
    n_samples, _, n_bands, h, w = out.shape
    n_cut = n_bands // num_band
    half = length // 2
    for s in range(n_samples):
        for band in np.random.permutation(n_bands)[:n_cut]:
            y = np.random.randint(h)
            x = np.random.randint(w)
            y1, y2 = int(np.clip(y - half, 0, h)), int(np.clip(y + half, 0, h))
            x1, x2 = int(np.clip(x - half, 0, w)), int(np.clip(x + half, 0, w))
            out[s, :, int(band), y1:y2, x1:x2] = 0
    return out


In [None]:
# Fine-tuning on the target domain.
# Each epoch samples n_tasks episodes from tune_set_5. For every episode:
# - the model is adapted on the (cutout-augmented) support set with
#   n_inner_iter differentiable SGD steps;
# - the adapted model's query loss is back-propagated into maml_model's own
#   parameters (copy_initial_weights=False).
# meta_opt then takes one step on the gradients accumulated over all episodes.
import higher

n_episodes = 300  # NOTE(review): not used below
epochs = 300
n_tasks = 16
n_inner_iter = 16
K2 = 5  # support patches per class per episode
N2 = 5  # query patches per class per episode

for k in range(epochs):
    maml_model.train()
    total_loss = 0.0
    tune_accuracies = []
    inner_opt = torch.optim.SGD(maml_model.parameters(), lr=1e-1)
    meta_opt.zero_grad()
    for task in range(n_tasks):
        with higher.innerloop_ctx(maml_model, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
            query_patches, support_patches, query_labels, support_labels, selected_classes = tune_episode(tune_set_5, K2, N2, test_class_labels)
            support_patches = cutout(support_patches, 2, 10).to(device)
            query_patches = cutout(query_patches, 2, 10).to(device)
            # Cross-entropy targets: position of each label in selected_classes.
            support_targets = torch.tensor([selected_classes.index(lbl) for lbl in support_labels], device=device)
            q_real = torch.tensor([selected_classes.index(lbl) for lbl in query_labels], device=device)
            for _ in range(n_inner_iter):
                diffopt.step(F.cross_entropy(fnet(support_patches), support_targets))
            q_logits = fnet(query_patches)
            q_loss = F.cross_entropy(q_logits, q_real)
            # Accumulates the meta-gradient into maml_model's parameters.
            q_loss.backward()
            # .item(): keep a plain float so no task's autograd graph is retained.
            total_loss += q_loss.item()
            tune_accuracies.append((q_logits.argmax(dim=1) == q_real).sum().item() / len(q_real))
    meta_opt.step()
    print(k, 'Loss', total_loss, 'Accuracy', float(np.mean(tune_accuracies)))
    if (k + 1) % 2 == 0:
        torch.save({'model_state_dict': maml_model.state_dict(),
                    'optimizer_state_dict': meta_opt.state_dict(),
                    'loss': total_loss,
                    }, checkpoint_prefix2)

In [None]:
def test_episode(test_patches_class,test_class_labels,test_C,test_K) :
  selected_classes = test_class_labels # [1, 2, 3, 4, 5, 6, 7, 8]
  support_labels = []
  query_labels = []
  support_patches = []
  query_patches = []
  for x in selected_classes :
    y = test_class_labels.index(x)
    support_imgs = test_patches_class[y][:test_K,:,:,:,:]
    query_imgs = test_patches_class[y][test_K:,:,:,:,:]
    support_patches.extend(support_imgs)
    query_patches.extend(query_imgs)
    for i in range(query_imgs.shape[0]) :
      query_labels.append(x)
    for i in range(test_K) :
      support_labels.append(x)
  temp1 = list(zip(query_patches, query_labels)) 
  random.shuffle(temp1) 
  query_patches, query_labels = zip(*temp1)
  x = len(query_labels)
  query_patches = torch.from_numpy(np.reshape(np.asarray(query_patches,dtype=np.float32),(x,im_height,im_width,im_depth,1)))
  support_patches = torch.from_numpy(np.reshape(np.asarray(support_patches,dtype=np.float32),(test_C*test_K,im_height,im_width,im_depth,1)))
  query_patches = query_patches.permute(0,4,3,2,1)
  support_patches = support_patches.permute(0,4,3,2,1)
  return query_patches, support_patches, query_labels, support_labels,x, list(selected_classes)      

In [None]:
# Fine-tuned checkpoint written by the fine-tuning loop (same path as checkpoint_prefix2).
checkpoint_prefixb = 'tuninghouston/ckpts/' + 'ckpt'

In [None]:
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
# torch.load unpickles the file: only load checkpoints you produced yourself.
checkpoint_tune = torch.load(checkpoint_prefixb, map_location=device)
maml_model.load_state_dict(checkpoint_tune['model_state_dict'])
meta_opt.load_state_dict(checkpoint_tune['optimizer_state_dict'])

In [None]:
K1 = 5  # support shots per class at test time; set to 10 for the 10-shot setting

In [None]:
# Make sure every parameter (including the restored head) lives on the device.
maml_model.to(device=device)

In [None]:
# Testing: adapt the fine-tuned model on the labelled support shots of every
# target class, then score it on held-out patches.
import higher

C = 16
n_tasks = 1
n_inner_iter = 16
total_loss = 0.0
# NOTE(review): maml_model is still in train() mode from the fine-tuning loop,
# so dropout and batch-statistics BatchNorm stay active during evaluation.
# Confirm this is intended.

# The first len(tune_set_5[j]) patches of every class were used for fine-tuning.
# Keep the first K1 of them as the labelled support shots, but draw queries only
# from patches after the tuning pool. Previously queries started at index K1,
# so tuning patches K1..9 were scored as test data.
eval_patches_class = [
    np.concatenate([class_patches[:K1], class_patches[len(tune_set_5[j]):]])
    for j, class_patches in enumerate(test_patches_class)
]

inner_opt = torch.optim.SGD(maml_model.parameters(), lr=1e-1)
for task in range(n_tasks):
    # track_higher_grads=False: nothing is meta-updated here, so skip building
    # the second-order graph through the inner loop.
    with higher.innerloop_ctx(maml_model, inner_opt, copy_initial_weights=False,
                              track_higher_grads=False) as (fnet, diffopt):
        (tquery_patches1, tsupport_patches1, query_labels1, support_labels1,
         x1, selected_classes1) = test_episode(eval_patches_class, test_class_labels, C, K1)
        tsupport_patches1 = tsupport_patches1.to(device)
        tquery_patches1 = tquery_patches1.to(device)
        # Targets are 0-based positions in selected_classes1.
        support_targets1 = torch.tensor([selected_classes1.index(lbl) for lbl in support_labels1], device=device)
        q_real1 = torch.tensor([selected_classes1.index(lbl) for lbl in query_labels1], device=device)
        for _ in range(n_inner_iter):
            diffopt.step(F.cross_entropy(fnet(tsupport_patches1), support_targets1))
        with torch.no_grad():
            q_logits1 = fnet(tquery_patches1)
            total_loss += F.cross_entropy(q_logits1, q_real1).item()
        q_pred = q_logits1.argmax(dim=1)
        accuracy1 = (q_pred == q_real1).sum().item() / len(q_real1)
        print(accuracy1)
        # Per-class accuracy; NaN for a class that has no query patches.
        classwise_mean_acc = [[] for _ in range(C)]
        for c in range(C):
            in_class = q_real1 == c
            n_in_class = int(in_class.sum().item())
            acc_c = (q_pred[in_class] == c).sum().item() / n_in_class if n_in_class else float('nan')
            classwise_mean_acc[c].append(acc_c)
            print(classwise_mean_acc[c])