In [1]:
import argparse
import glob
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from matplotlib import pyplot as plt
from pandas import read_fwf, DataFrame
from scipy.ndimage.interpolation import rotate as sc_rotate
from scipy.optimize import curve_fit, least_squares, minimize
from torchvision import datasets, transforms
from tqdm import tqdm_notebook as tqdm

In [2]:
from radioreader import *
from methods import *
from kittler import kittler_float

from InterpretableAE import *

In [3]:
# Location and extension of the input images.
directory = 'unlrg'
ext = 'fits'

# glob returns matches in filesystem-dependent order; sorting makes the
# dataset index -> file mapping (and therefore seeded shuffling) reproducible.
# The trailing '*' also matches suffixed variants of the extension.
names = sorted(glob.glob('{0}/*.{1}*'.format(directory, ext)))

In [4]:
# Side length handed to readImg through its `sz` argument.
image_size = 128

# Fail early with a clear message instead of an opaque shape/index error later on.
if not names:
    raise FileNotFoundError('no input images found; check `directory` and `ext`')

# Read every file, pass it through kittler_float and collect the results.
# NOTE(review): kittler_float presumably applies Kittler thresholding - confirm in kittler.py.
images = []
for path in tqdm(names):
    raw = readImg(path, normalize=True, sz=image_size)
    processed = kittler_float(raw, copy=False)
    # Prepend a singleton axis so each entry carries an explicit channel dimension.
    images.append(np.expand_dims(processed, axis=0))

HBox(children=(IntProgress(value=0, max=14245), HTML(value='')))




In [5]:
# Stack into a single ndarray first: building a tensor straight from a Python
# list of ndarrays is a very slow conversion path in PyTorch.
image_stack = np.stack(images)
# torch.from_numpy rejects non-native byte order, so normalise it (no copy is
# made when the array is already native).
image_stack = image_stack.astype(image_stack.dtype.newbyteorder('='), copy=False)
im_tensor = torch.from_numpy(image_stack)
im_tensor.shape

torch.Size([14245, 1, 128, 128])

In [6]:
# Reproducibility: seed every RNG already available in this notebook.
seed = 1
torch.manual_seed(seed)
# NOTE(review): random_rotate / rotate_tensor presumably draw from NumPy's
# global RNG - confirm in methods.py. Seeding it is harmless either way.
np.random.seed(seed)

# Prefer the GPU when one is visible to torch.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

# Arguments
batch_size = 64
test_batch_size = 64

# Image height/width are taken from axes 2 and 3 of the stacked image tensor.
ts = list(im_tensor.shape)
imh, imw = ts[2], ts[3]
print(im_tensor.shape)
print('imh, imw = ',imh, imw)

epochs = 100
learning_rate = 0.001
# 'SGD momentum'. NOTE(review): not referenced by the visible cells, which
# build an Adam optimizer - confirm before removing.
momentum = 0.9
latent_space=16

cuda
torch.Size([14245, 1, 128, 128])
imh, imw =  128 128


In [7]:
# Build the autoencoder, move it to the selected device, then wire up the
# optimiser and a shuffling mini-batch loader over the image tensor.
model = InterpretableAE(imh, imw, device, latent_dim=latent_space)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
train_loader = torch.utils.data.DataLoader(
    dataset=im_tensor,
    batch_size=batch_size,
    shuffle=True,
)

In [8]:
def train(model, device, train_loader, optimizer, epoch, log_interval=5):
    """Run one training epoch over ``train_loader``.

    Each batch is passed through ``random_rotate`` and then ``rotate_tensor``
    (both defined outside this cell); the latter yields the targets and the
    angles. The model is called as ``model(inputs, angles)`` and optimised
    with a summed binary cross-entropy loss against the targets.

    Args:
        model: network to train; called as ``model(inputs, angles)``.
        device: torch device the batch tensors are moved to.
        train_loader: iterable of CPU tensors (``.numpy()`` is called on each
            batch); must support ``len()`` and expose ``.dataset``.
        optimizer: optimizer stepping the parameters of ``model``.
        epoch: epoch number, used only in the progress line.
        log_interval: a progress line is written every ``log_interval`` batches.

    Returns:
        float: summed BCE loss averaged over the samples seen in this epoch,
        or ``0.0`` if the loader yielded no batches.
    """
    model.train()

    # Loop-invariant: build the criterion once rather than once per batch.
    loss_fnc = nn.BCELoss(reduction='sum')
    n_batches = len(train_loader)
    n_total = len(train_loader.dataset)

    loss_sum = 0.0
    n_seen = 0
    for batch_idx, data in enumerate(train_loader):
        # The augmentation helpers operate on NumPy arrays.
        inputs_np = random_rotate(data.numpy())
        targets_np, angles_np = rotate_tensor(inputs_np)
        targets = torch.from_numpy(targets_np).to(device, dtype=torch.float)
        angles = torch.from_numpy(angles_np).to(device)
        # One angle per sample, as a column vector.
        angles = angles.view(angles.size(0), 1)

        # Forward pass
        inputs = torch.from_numpy(inputs_np).to(device, dtype=torch.float)
        optimizer.zero_grad()
        output = model(inputs, angles)

        # Binary cross entropy loss
        loss = loss_fnc(output, targets)

        # Backprop
        loss.backward()
        optimizer.step()

        # Accumulate on-device; a single host sync happens after the loop.
        loss_sum = loss_sum + loss.detach()
        n_seen += len(data)

        if batch_idx % log_interval == 0:
            # '\r' keeps the progress on one line, overwritten on each update.
            sys.stdout.write('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\r'
                .format(epoch, batch_idx * len(data), n_total,
                100. * batch_idx / n_batches, loss.item()))
            sys.stdout.flush()

    if n_seen == 0:
        return 0.0
    return float(loss_sum) / n_seen

In [None]:
%%time
for epoch in range(1, epochs + 1):
    start = time.time()
    train(model, device, train_loader, optimizer, epoch)
    end = time.time()
    sys.stdout.write('\n Time: {0:.2f}s\n'.format(end - start))

 Time: 58.45s
 Time: 56.95s
 Time: 57.19s
 Time: 56.44s
 Time: 56.40s
 Time: 56.32s
 Time: 56.62s
 Time: 56.69s
 Time: 56.90s
 Time: 56.48s
 Time: 57.06s
 Time: 56.77s
 Time: 57.22s
 Time: 56.32s
 Time: 56.21s
 Time: 56.71s
 Time: 56.74s
 Time: 56.41s
 Time: 56.28s
 Time: 56.24s
 Time: 56.24s
 Time: 56.19s
 Time: 56.09s
 Time: 56.23s
 Time: 56.16s
 Time: 56.47s
 Time: 56.18s
 Time: 56.14s
 Time: 56.25s
 Time: 56.19s
 Time: 56.08s
 Time: 56.08s
 Time: 56.15s
 Time: 56.17s
 Time: 56.23s
 Time: 56.25s
 Time: 56.20s
 Time: 56.09s
 Time: 56.21s
 Time: 56.26s
 Time: 57.21s
 Time: 57.07s
 Time: 56.81s
 Time: 56.23s
 Time: 56.13s
 Time: 56.17s
 Time: 56.01s
 Time: 56.21s
 Time: 56.25s
 Time: 56.24s
 Time: 56.24s
 Time: 56.21s
 Time: 56.20s
 Time: 56.26s
 Time: 56.23s
 Time: 56.23s
 Time: 56.24s
 Time: 56.19s
 Time: 56.12s
 Time: 56.15s
 Time: 56.28s
 Time: 56.19s
 Time: 56.21s
 Time: 56.23s
 Time: 56.10s
 Time: 56.16s
 Time: 56.18s
 Time: 56.25s
 Time: 56.23s
 Time: 56.18s
 Time: 56.37s
 Time:

In [None]:
# Persist only the learned weights (state_dict). The latent size in the file
# name comes from the configuration, so it cannot drift from the saved model.
torch.save(model.state_dict(), 'unlrg_conv_model_latent{0}'.format(latent_space))