In [56]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import matplotlib.pyplot as plt

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os.path as osp
from PIL import Image
from PIL import ImageEnhance
import matplotlib.cm as cm

import torchvision.datasets as dset
import torchvision.transforms as T
import chest_xray_code.data.xrays as preprocess_dataset
import chest_xray_code.data.raw_reports as utils
import os
import torch.nn.functional as F
from models.NewConvModel import NewConvNet 
from models.TestConvNet import TestConvNet
from loaders.XrayLoader import XrayLoader
from loaders.BloodCellLoader import BloodCellLoader
from loaders.MuseumLoader import MuseumLoader

import numpy as np

In [7]:
# Settings shared by all three loaders. PyTorch warns when num_workers exceeds
# the number of CPUs, so cap the worker count (os.cpu_count() may return None).
BATCH_SIZE = 20
NUM_WORKERS = min(32, os.cpu_count() or 1)

xray_set = XrayLoader(
    root='chest_xray_code/data/xrays',
    preload=False, transform=transforms.ToTensor(),
)
xray_loader = DataLoader(xray_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

blood_set = BloodCellLoader(
    root='blood_cells_data/dataset-master/JPEGImages',
    preload=False, transform=transforms.ToTensor(),
)
blood_cell_loader = DataLoader(blood_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

museum_set = MuseumLoader(
    root='museum_data/dataset_updated/training_set',
    preload=False, transform=transforms.ToTensor(),
)
museum_loader = DataLoader(museum_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

print(len(museum_set))
print(len(blood_set))
print(len(xray_set))

500
366
500


In [8]:
transform = T.Compose([
    T.ToTensor(),
    # T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Storage for save_original / save_compressed, which append to images[0]
# (original images) and images[1] (compressed feature maps) by integer index.
# A dict here would make those appends raise KeyError.
images = [[], []]

In [9]:
USE_GPU = True

# float32 is the working dtype for the whole notebook
dtype = torch.float32

# Pick CUDA only when it is both requested and present.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Make newly created tensors CUDA floats by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# How often (in iterations) the training loss is printed
print_every = 5


In [16]:
# Experiment configuration.
# channels and size are passed as the first two constructor arguments
# of NewConvNet / TestConvNet.
channels, size = 3, 32
net_type = 'old'                 # 'new' -> NewConvNet checkpoints, otherwise TestConvNet ones
original_dataset = 'blood_cell'  # training dataset of the checkpoint: 'blood_cell' | 'museum' | 'xray'
cross_dataset = 'blood_cell'     # dataset fed through the model: 'blood_cell' | 'museum' | 'xray'
model = None                     # None until the loading cell has run
hook = None

In [17]:
# Checkpoint file (under trained_models/) for each architecture and training
# dataset. Any original_dataset other than 'blood_cell' or 'museum' falls back
# to the x-ray checkpoint.
MODEL_CHECKPOINTS = {
    'new': {'blood_cell': 'blood5.pt', 'museum': 'new_museum.pt', 'xray': 'new_xray.pt'},
    'old': {'blood_cell': 'blood_200_1000.pt', 'museum': 'old_museum.pt', 'xray': 'xraymodelV2.pt'},
}

if model is None:  # ensure that we don't keep loading it over and over
    arch_checkpoints = MODEL_CHECKPOINTS['new' if net_type == 'new' else 'old']
    checkpoint = arch_checkpoints.get(original_dataset, arch_checkpoints['xray'])
    # torch.load unpickles a whole model object: only load trusted checkpoints.
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    model = torch.load('trained_models/' + checkpoint, map_location=device)
    print("loaded " + checkpoint)
    # Move to the target device for both architectures, and switch to eval
    # mode: the model is only used for inference in this notebook.
    model.to(device)
    model.eval()

# Loader whose images are fed through the model; unknown names fall back to blood cells.
cross_dataset_loader = {'museum': museum_loader, 'xray': xray_loader}.get(cross_dataset, blood_cell_loader)


loaded blood_200_1000.pt




Visualize the compressed images

In [18]:
def prep(img, normalize=True):
    """Convert a CHW tensor into an HWC numpy array suitable for ``plt.imshow``.

    Args:
        img: CPU tensor whose first axis is channels (it is transposed
            ``(1, 2, 0)``); must support ``.numpy()``.
        normalize: when true (the default, as before), min-max rescale the
            values into ``[0, 0.9]``. A constant image has zero range and is
            mapped to all zeros instead of producing NaNs from 0/0.

    Returns:
        The HWC numpy array.
    """
    arr = img.numpy()
    if normalize:
        lo, hi = np.min(arr), np.max(arr)
        span = hi - lo
        if span > 0:
            arr = .9 * (arr - lo) / span
        else:
            arr = np.zeros_like(arr)
    return np.transpose(arr, (1, 2, 0))

def save_compressed(self, input, output):
    """Forward-hook callback: stash every item of ``output`` in ``images[1]``.

    The hook arguments are (module, input, output); only ``output`` is used.
    It is moved to the CPU, detached, split along dim 0, and each slice is
    converted with ``prep`` before being appended to the global ``images[1]``.
    """
    detached = output.cpu().detach()
    for feature_map in detached:
        images[1].append(prep(feature_map))

        
def save_original(data):
    """Append each item of ``data`` (split along dim 0) to ``images[0]``.

    Each item is moved to the CPU, detached, converted to numpy and transposed
    ``(1, 2, 0)`` (channels-first -> channels-last) without any rescaling.
    """
    batch = data.cpu().detach()
    for idx in range(batch.shape[0]):
        hwc = np.transpose(batch[idx].numpy(), (1, 2, 0))
        images[0].append(hwc)

def _load_checkpoint(net_type, original_dataset, device):
    """Load the trained model for ``(net_type, original_dataset)`` onto ``device`` in eval mode."""
    if net_type == 'new':
        checkpoints = {'blood_cell': 'blood5.pt', 'museum': 'new_museum.pt'}
        fallback = 'new_xray.pt'
    else:
        checkpoints = {'blood_cell': 'blood_200_1000.pt', 'museum': 'old_museum.pt'}
        fallback = 'xraymodelV2.pt'
    checkpoint = checkpoints.get(original_dataset, fallback)
    # torch.load unpickles a whole model object: only load trusted checkpoints.
    model = torch.load('trained_models/' + checkpoint, map_location=device)
    print("loaded " + checkpoint)
    model.to(device)
    model.eval()  # inference only
    return model


def _loader_for(cross_dataset):
    """Return the DataLoader for ``cross_dataset``; unknown names fall back to blood cells."""
    return {'museum': museum_loader, 'xray': xray_loader}.get(cross_dataset, blood_cell_loader)


def _to_hwc(tensor):
    """Detach ``tensor`` to a CPU numpy array and transpose ``(1, 2, 0)`` (CHW -> HWC)."""
    return np.transpose(tensor.cpu().detach().numpy(), (1, 2, 0))


def visualize(net_type, original_dataset, cross_dataset, device, num_images=10):
    """Plot original / compressed / reconstructed images for one model and dataset.

    Loads the checkpoint selected by ``net_type`` and ``original_dataset``, runs
    the first batch of ``cross_dataset`` through it, captures the output of
    ``model.conv_compress_final`` with a temporary forward hook, and saves one
    three-panel figure per image under ``cross_visualizations/``.

    Args:
        net_type: 'new' for NewConvNet checkpoints, anything else for TestConvNet.
        original_dataset: dataset the checkpoint was trained on.
        cross_dataset: dataset whose first batch is visualised.
        device: torch device to run on.
        num_images: maximum number of figures to produce (default 10, as before);
            fewer are produced if the batch is smaller.
    """
    model = _load_checkpoint(net_type, original_dataset, device)
    loader = _loader_for(cross_dataset)

    # Per-call buffers, so repeated calls never see each other's images.
    originals, compressed = [], []

    def _capture_compressed(module, inputs, output):
        compressed.extend(prep(feature_map) for feature_map in output.cpu().detach())

    plt.close("all")
    hook = model.conv_compress_final.register_forward_hook(_capture_compressed)
    try:
        with torch.no_grad():
            data = next(iter(loader)).to(device)
            originals.extend(_to_hwc(img) for img in data)
            reconstruction = model(data)
    finally:
        # Always detach the hook so it does not keep firing on later forwards.
        hook.remove()

    os.makedirs("cross_visualizations", exist_ok=True)
    for i in range(min(num_images, len(originals), len(compressed))):
        fig = plt.figure(figsize=(100, 100))
        panels = (
            ('Original', originals[i]),
            ('Compressed', compressed[i]),
            ('Reconstructed', np.clip(_to_hwc(reconstruction[i]), 0, 1)),
        )
        for position, (title, image) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.axis('off')
            plt.imshow(image)
            plt.title(title)
        image_str = net_type + "_" + original_dataset + "_" + cross_dataset + "_" + str(i) + ".png"
        plt.savefig("cross_visualizations/" + image_str)
        plt.show()
        plt.close(fig)  # avoid accumulating huge figures across many calls



Store the images as JPEG for comparison, at the same compression rate the model achieves

In [None]:
# Three independent buckets: [0] originals, [1] model reconstructions, [2] JPEG round-trips.
jpg_images = [[] for _ in range(3)]

def check_sizes(img, qualities=range(1, 100), out_dir="jpeg_visualizations"):
    """Save ``img`` as JPEG at each quality level and report the resulting sizes.

    This is used to find the JPEG quality whose output size matches a target
    compression rate. The decoded array shape is printed as before, but it is
    the same at every quality, so the on-disk size in bytes is printed as well.

    Args:
        img: a PIL image (anything with ``save(path, "JPEG", quality=...)``).
        qualities: JPEG quality levels to try (default 1..99, as before).
        out_dir: directory for the scratch JPEG files; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    for quality in qualities:
        jpeg_filename = os.path.join(out_dir, str(quality) + "_jaypeg.jpeg")
        img.save(jpeg_filename, "JPEG", quality=quality)
        # Context manager releases the file handle immediately.
        with Image.open(jpeg_filename) as reloaded:
            shape = np.asarray(reloaded).shape
        print(shape, quality, os.path.getsize(jpeg_filename), "bytes")
        
        

def save_original_and_jpeg(data):
    """Store each image of ``data`` and its JPEG round-trip in ``jpg_images``.

    For every item of ``data`` (split along dim 0, transposed ``(1, 2, 0)``;
    values presumably in [0, 1] as produced by ``ToTensor`` — confirm for other
    transforms):

    * the HWC float array is appended to ``jpg_images[0]``;
    * it is scaled to 0-255, cast to uint8, written to
      ``jpeg_visualizations/<i>_jaypeg.jpeg`` at JPEG quality 100, read back,
      and the decoded array is appended to ``jpg_images[2]``.

    File names use the index within the batch, so a later call overwrites the
    files written by an earlier one (the decoded arrays are kept in memory).
    """
    os.makedirs("jpeg_visualizations", exist_ok=True)
    batch = data.cpu().detach()

    for i in range(batch.shape[0]):
        individual_img = np.transpose(batch[i].numpy(), (1, 2, 0))
        jpg_images[0].append(individual_img)

        # Clip before the uint8 cast so out-of-range floats saturate instead
        # of wrapping around.
        rescaled = np.clip(255.0 * individual_img, 0, 255).astype('uint8')
        PIL_img = Image.fromarray(rescaled)

        # TODO(review): quality is fixed at 100; presumably it should be chosen
        # (e.g. with check_sizes) to match the model's compression rate.
        jpeg_filename = "jpeg_visualizations/" + str(i) + "_jaypeg.jpeg"
        PIL_img.save(jpeg_filename, "JPEG", quality=100)

        # Context manager releases the file handle; np.array copies the pixels
        # so the result stays valid after the file is closed.
        with Image.open(jpeg_filename) as jpeg_file:
            jpg_images[2].append(np.array(jpeg_file))
        
def save_reconstruction(reconstruction):
    """Append each item of ``reconstruction`` to ``jpg_images[1]``.

    Every item (split along dim 0) is moved to the CPU, detached, transposed
    ``(1, 2, 0)`` and clipped to [0, 1] so it can be shown with ``plt.imshow``.
    """
    for sample in range(reconstruction.shape[0]):
        hwc = np.transpose(reconstruction[sample].cpu().detach().numpy(), (1, 2, 0))
        jpg_images[1].append(np.clip(hwc, 0, 1))
        
# Start from empty buckets so re-running this cell does not show stale images.
for bucket in jpg_images:
    bucket.clear()

# Run the first batch of the cross dataset through the model, collecting the
# originals, their JPEG round-trips and the model reconstructions.
model.eval()  # inference only
with torch.no_grad():
    data = next(iter(cross_dataset_loader)).to(device)
    save_original_and_jpeg(data)
    reconstruction = model(data)
    save_reconstruction(reconstruction)

# Show up to 10 (original, reconstruction, JPEG) triples side by side.
for i in range(min(10, *(len(bucket) for bucket in jpg_images))):
    fig = plt.figure(figsize=(100, 100))
    panels = (
        ('Original', jpg_images[0][i]),
        ('Reconstructed', jpg_images[1][i]),
        ('jpeg', jpg_images[2][i]),
    )
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.axis('off')
        plt.imshow(image)
        plt.title(title)
    plt.show()
    plt.close(fig)  # release the large figure once displayed

Requires a trained checkpoint for every network / dataset combination.

In [None]:
# Render every (architecture, training dataset, evaluation dataset) combination.
# The loop variables are deliberately distinct from the configuration globals
# (net_type, original_dataset, cross_dataset) so the sweep does not clobber them.
cross_datasets = ['xray', 'museum', 'blood_cell']
net_types = ['new', 'old']
original_datasets = ['xray', 'museum', 'blood_cell']

for sweep_net in net_types:
    for sweep_original in original_datasets:
        for sweep_cross in cross_datasets:
            try:
                visualize(sweep_net, sweep_original, sweep_cross, device)
            except FileNotFoundError as err:
                # Typically a checkpoint that has not been trained yet: report
                # it and continue with the remaining combinations.
                print("skipping", sweep_net, sweep_original, sweep_cross, "-", err)
    


Visualize the Reconstructed Images