In [None]:
import torch
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
import numpy as np
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
import cv2
import os
import time
from sklearn.manifold import TSNE
from tqdm import tqdm
import matplotlib.pyplot as plt
cuda = torch.cuda.is_available()
from torchvision import datasets
from torchvision import transforms
import matplotlib as mpl
from dataset import SiameseCeliac
import argparse
from network import SupConResNet, LinearClassifier
from utils import TwoCropTransform, AverageMeter
from utils import adjust_learning_rate, warmup_learning_rate
from utils import set_optimizer, save_model
from losses import SupConLoss
import torch.backends.cudnn as cudnn
import sys
from sklearn.decomposition import PCA
from PIL import Image

In [None]:
# File paths.
# The dataset root is written once and the per-split folders are derived from
# it, so relocating the dataset means editing a single line. The trailing '/'
# is kept on every path: the checkpoint file name is later appended to exp_dir
# by plain string concatenation.
directory_path_root = '/home/aayush/Aayush/Projects/Celiac_Disease/Detection/Dataset/Celiac_cropped_patches_New_data/'
train_directory_path = directory_path_root + 'train/'
val_directory_path = directory_path_root + 'val/'
test_directory_path = directory_path_root + 'test/'
# Folder that holds the trained SupCon checkpoint.
exp_dir = '/home/aayush/Aayush/Projects/Celiac_Disease/Detection/Code/Postreg/Contrastive_learning/save/SupCon/Celiac_Disease_bal_models/SupCon_Celiac_Disease_resnet18_lr_0.0001_decay_0.0001_bsz_8_temp_0.7_trial_0_kfold_0/'
# Folder that receives the fitted PCA pickle.
output_dir = '/home/aayush/Aayush/Projects/Celiac_Disease/Detection/Code/Postreg/Contrastive_learning/PCA/Resnet18_temp0d7/PCA_Contrastive_Balance/'

In [None]:
# Inference settings shared by the cells below.
image_size = 224  # size handed to transforms.Resize in set_loader()
batch_size = 1    # images per DataLoader batch
syncBN = False    # when True, set_model() converts the model to synchronized BatchNorm

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def set_model():
    """Build the SupCon ResNet-18, load its trained weights and move it to the GPU.

    The checkpoint ``ckpt_resnet18_epoch_250_temp_0.7.pth`` is read from
    ``exp_dir`` and its ``'model'`` entry is loaded into a fresh
    ``SupConResNet``. When CUDA is available the model and the loss are moved
    to the GPU, and the encoder is wrapped in ``DataParallel`` if more than
    one GPU is present.

    Returns:
        tuple: ``(model, criterion)`` - the ``SupConResNet`` with the
        checkpoint weights loaded, and a ``SupConLoss`` instance.
    """
    model = SupConResNet(name='resnet18')
    # NOTE(review): the checkpoint name says temp_0.7 while the loss is built
    # with temperature=0.01. The visible cells never call the criterion, so
    # this looks harmless here - confirm before reusing it for training.
    criterion = SupConLoss(temperature=0.01)

    # load model
    model_dir = exp_dir + 'ckpt_resnet18_epoch_250_temp_0.7.pth'
    # map_location='cpu' lets a checkpoint that was saved from a GPU be opened
    # on a CPU-only machine; the model is moved to the GPU further down.
    checkpoint = torch.load(model_dir, map_location='cpu')

    # The weights are loaded into an un-wrapped encoder. Presumably the
    # checkpoint may have been saved with the encoder wrapped in DataParallel
    # (as done below on multi-GPU machines), which would prefix its keys with
    # 'encoder.module.' - TODO confirm against the training script. Keys
    # without that prefix pass through unchanged.
    wrapped_prefix = 'encoder.module.'
    state_dict = {
        ('encoder.' + key[len(wrapped_prefix):] if key.startswith(wrapped_prefix) else key): value
        for key, value in checkpoint['model'].items()
    }
    model.load_state_dict(state_dict)

    # enable synchronized Batch Normalization
    if syncBN:
        # apex is not imported anywhere in the visible cells, so import it
        # here, only when synchronized BatchNorm is actually requested.
        import apex
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion

In [None]:
def set_loader():
    """Create the train, validation and test data loaders used for inference.

    Each split is read with ``datasets.ImageFolder`` and passes through the
    same resize -> tensor -> normalize pipeline. Only the training loader
    shuffles its samples.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``.
    """
    channel_means = (0.485, 0.456, 0.406)
    channel_stds = (0.229, 0.224, 0.225)

    def make_pipeline():
        # Resize, convert to a tensor, then normalize each channel.
        steps = [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(channel_means, channel_stds),
        ]
        return transforms.Compose(steps)

    train_transform = make_pipeline()
    test_transform = make_pipeline()

    # Datasets first, in train / val / test order.
    split_specs = (
        (train_directory_path, train_transform),
        (val_directory_path, test_transform),
        (test_directory_path, test_transform),
    )
    train_data, valid_data, test_data = [
        datasets.ImageFolder(folder, transform=pipeline)
        for folder, pipeline in split_specs
    ]

    # Then the loaders; shuffling is enabled for the training split only.
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = make_loader(valid_data, batch_size=batch_size, shuffle=False)
    test_loader = make_loader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [None]:
# Build the three data loaders, then the model with its checkpoint loaded.
# `criterion` is returned by set_model() but is not used by the visible cells.
train_loader, val_loader, test_loader = set_loader()
model, criterion = set_model()

In [None]:
def generate_embeddings(data_loader, model):
    """Run ``model.encoder`` over every batch of ``data_loader`` and collect the outputs.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. Inputs are moved to the module-level ``device`` before being
    fed to the encoder.

    Args:
        data_loader: iterable yielding ``(images, labels)`` batches of tensors.
        model: object exposing ``eval()`` and an ``encoder`` callable.

    Returns:
        tuple: ``(embeddings, labels)`` as NumPy arrays concatenated along
        axis 0 in iteration order, or ``(None, None)`` when ``data_loader``
        yields no batches.
    """
    embedding_chunks = []
    label_chunks = []

    model.eval()
    with torch.no_grad():
        # Iterating the loader directly (instead of wrapping it in enumerate)
        # lets tqdm read its length and display a total.
        for batch_imgs, batch_labels in tqdm(data_loader):
            batch_embeddings = model.encoder(batch_imgs.to(device))
            embedding_chunks.append(batch_embeddings.detach().cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    if not embedding_chunks:
        return None, None

    # Concatenate once at the end: doing it per batch re-copies everything
    # gathered so far on every iteration, which is quadratic in the number of
    # batches.
    embeddings = np.concatenate(embedding_chunks, axis=0)
    labels = np.concatenate(label_chunks, axis=0)
    return embeddings, labels

In [None]:
# Embed every split with the loaded model; each call returns (embeddings, labels).
embeddings_train, labels_train = generate_embeddings(train_loader, model)
embeddings_val, labels_val = generate_embeddings(val_loader, model)
embeddings_test, labels_test = generate_embeddings(test_loader, model)

# Sanity check: shapes of the test-split embeddings and labels.
print(embeddings_test.shape , labels_test.shape)

In [None]:
def vis_tSNE(embeddings, labels, split='train', save=False):
    """Show a 2-D t-SNE scatter plot of ``embeddings`` coloured by ``labels``.

    At most the first 10000 rows are used. They are first reduced with PCA
    (up to 32 components) and then projected to two dimensions with t-SNE.
    Both steps use the same fixed seed so the figure is reproducible.

    Args:
        embeddings: 2-D array with one row per sample.
        labels: 1-D array of numeric labels aligned with ``embeddings``.
        split: name of the data split; shown in the title and used in the
            file name when the figure is saved.
        save: when True, also write the figure as a JPEG into the ``tSNE``
            sub-folder of ``exp_dir``. Defaults to False (display only).
    """
    tSNE_ns = 10000
    seed = 1001

    num_samples = min(tSNE_ns, embeddings.shape[0])
    subset = embeddings[0:num_samples, :]

    # PCA cannot return more components than min(n_samples, n_features), so
    # clamp the requested 32 for small or low-dimensional inputs.
    n_components = min(32, subset.shape[0], subset.shape[1])
    pca = PCA(n_components=n_components, svd_solver='full', random_state=seed)
    X_pca = pca.fit_transform(subset)
    X_embedded = TSNE(n_components=2, random_state=seed).fit_transform(X_pca)

    fig, ax = plt.subplots()
    x, y = X_embedded[:, 0], X_embedded[:, 1]
    colors = plt.cm.rainbow(np.linspace(0, 1, 10))
    sc = ax.scatter(x, y, c=labels[0:num_samples], cmap=mpl.colors.ListedColormap(colors))
    fig.colorbar(sc, ax=ax)
    ax.set_title('t-SNE of {} embeddings ({} samples)'.format(split, num_samples))

    if save:
        # Only touch the filesystem when a file is actually written.
        # exist_ok avoids the check-then-create race.
        out_dir = os.path.join(exp_dir, 'tSNE')
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, 'tSNE_edge_{}_'.format(split) + str(num_samples) + '.jpg'))
    plt.show()

# Visualise the training-split embeddings (split defaults to 'train').
vis_tSNE(embeddings_train, labels_train)

In [None]:
from sklearn.decomposition import PCA
import pickle as pk

# Fit a PCA on the training embeddings. A float n_components between 0 and 1
# asks scikit-learn to keep as many components as are needed to explain that
# fraction of the variance (here 90%).
pca = PCA(.90)
pca.fit(embeddings_train)
print(pca.n_components_)

# Persist the fitted PCA. Create the target folder if it is missing, and use a
# context manager so the file handle is closed (and flushed) deterministically.
os.makedirs(output_dir, exist_ok=True)
pca_path = os.path.join(output_dir, "pca_Celiac_contrastive_balance_new_data.pkl")
with open(pca_path, "wb") as pca_file:
    pk.dump(pca, pca_file)