In [1]:
from google.colab import drive

# Colab only: mount Google Drive so the project folder used below is reachable.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import numpy as np
import json
from pprint import pprint
import matplotlib.pyplot as plt
from skimage import io
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import os
import time

# Fall back to CPU so the notebook still runs on a runtime without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed both RNGs the notebook uses: numpy (minibatch / sample selection in
# Net.fit) and torch (weight initialisation and dropout).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [0]:
# Work from the project folder on Drive: the data, log and checkpoint paths
# used in later cells are all relative to this directory.
PROJECT_DIR = '/content/drive/My Drive/Colab Notebooks/measure_model_quarters'
os.chdir(PROJECT_DIR)

In [0]:
# Preprocessed arrays (paths relative to the working directory set by the
# os.chdir cell above), loaded in a fixed order.
(images,
 subsequences,
 image_indices,
 measure_data_channels) = [
    np.load(npy_path)
    for npy_path in (
        'preprocessed/quarter_measure_data_images_preprocessed.npy',
        'preprocessed/quarter_measure_data_subsequences.npy',
        'preprocessed/quarter_measure_data_subsequence_image_indices.npy',
        'preprocessed/quarter_measure_data_time_and_key_tiles.npy',
    )
]

# Lexicon and per-image auxiliary metadata.
with open('preprocessed/other_data.json') as f:
    other_data = json.load(f)

In [0]:
# Token <-> index maps. ix_to_word is keyed by the *string* form of the index
# (it is looked up with ix_to_word[str(ix)] elsewhere in the notebook).
word_to_ix = other_data['lexicon']['word_to_ix']
ix_to_word = other_data['lexicon']['ix_to_word']
len_lexicon = len(word_to_ix)

# The model emits len_lexicon logits and decoding maps the argmax back through
# ix_to_word, so every index in [0, len_lexicon) must be decodable.
assert all(str(ix) in ix_to_word for ix in range(len_lexicon)), \
    'ix_to_word does not cover every index in range(len(word_to_ix))'

In [0]:
# Training hyperparameters.
# batch_size: sequences per minibatch in Net.fit.
# seq_len: NOTE(review) not referenced by any visible cell.
batch_size, seq_len = 64, 64
# lstm_hidden_size: hidden units per LSTM layer.
# fc1_output_size: width of the image feature concatenated into lstm2's input.
lstm_hidden_size, fc1_output_size = 5, 10

In [0]:
class ConvSubunit(nn.Module):
    """Conv2d -> Dropout2d -> BatchNorm2d -> ReLU, applied in that order.

    Each layer is kept as a named attribute (conv / dp / bn / relu) and is also
    chained inside `sequential`, which is what `forward` runs.
    """

    def __init__(self, input_size, output_size, filter_size, stride, padding, dropout):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=input_size,
                              out_channels=output_size,
                              kernel_size=filter_size,
                              stride=stride,
                              padding=padding)
        self.dp = nn.Dropout2d(p=dropout)
        self.bn = nn.BatchNorm2d(num_features=output_size)
        self.relu = nn.ReLU()
        layers = (self.conv, self.dp, self.bn, self.relu)
        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        activated = self.sequential(x)
        return activated

class ConvUnit(nn.Module):
    """One convolutional stage of the image encoder.

    Currently a thin wrapper around a single `ConvSubunit`; its parameters are
    registered under the `subunit1.` prefix.
    """

    def __init__(self, input_size, output_size, filter_size, stride, padding, dropout):
        super().__init__()
        self.subunit1 = ConvSubunit(input_size=input_size,
                                    output_size=output_size,
                                    filter_size=filter_size,
                                    stride=stride,
                                    padding=padding,
                                    dropout=dropout)

    def forward(self, x):
        return self.subunit1(x)

class Net(nn.Module):
    """Image-conditioned LSTM language model over the notebook's token lexicon.

    Pipeline (see `forward`):
      * `cnn` + `fc1` encode the image input into a `fc1_output_size` vector.
        The shape comments on `cnn` assume a (2, 200, 200) input, which
        flattens to the 512 features `fc1` expects (128 channels x 2 x 2).
      * Token ids are embedded (dim 5) and run through `lstm1`.
      * The image vector is repeated at every time step, concatenated with
        `lstm1`'s output, run through `lstm2`, and projected to
        `len_lexicon` logits by `fc2`.

    NOTE(review): `fit`, `predict` and `predict_from_image` read notebook
    globals (`images`, `subsequences`, `image_indices`,
    `measure_data_channels`, `other_data`, `batch_size`, `word_to_ix`,
    `ix_to_word`), so the data-loading cells must have run first.
    """

    def __init__(self, len_lexicon, lstm_hidden_size, fc1_output_size, device):
        # `device` is accepted for backward compatibility but not used: runtime
        # tensors are created on the device of the model's own parameters (see
        # `_device`), so moving the model with `.to(...)` is all that's needed.
        super().__init__()
        self.len_lexicon = len_lexicon
        self.lstm_hidden_size = lstm_hidden_size
        self.fc1_output_size = fc1_output_size
        self.num_iterations = 0  # total optimizer steps across all fit() calls
        self.train_time = 0      # total wall-clock seconds spent inside fit()
        self.cnn = nn.Sequential(ConvUnit(2, 64, 3, 2, 1, 0.25), # (200, 200) --> (100, 100)
                                 ConvUnit(64, 128, 3, 2, 1, 0.25), # (100, 100) --> (50, 50)
                                 ConvUnit(128, 128, 3, 5, 1, 0.25), # (50, 50) --> (10, 10)
                                 ConvUnit(128, 128, 3, 5, 1, 0.25)) # (10, 10) --> (2, 2)
        self.fc1 = nn.Linear(512, self.fc1_output_size)
        self.embed = nn.Embedding(num_embeddings=self.len_lexicon, embedding_dim=5)
        self.lstm1 = nn.LSTM(input_size=5, hidden_size=self.lstm_hidden_size, num_layers=2, batch_first=True, dropout=0.25)
        self.lstm2 = nn.LSTM(input_size=self.fc1_output_size+self.lstm_hidden_size, hidden_size=self.lstm_hidden_size, num_layers=2, batch_first=True, dropout=0.25)
        self.fc2 = nn.Linear(self.lstm_hidden_size, self.len_lexicon)

    def _device(self):
        """Device the model's parameters currently live on."""
        return self.fc2.weight.device

    def _zero_state(self, lstm, n):
        """Zero (h, c) initial state for `lstm` with batch size `n`."""
        shape = (lstm.num_layers, n, self.lstm_hidden_size)
        dev = self._device()
        return torch.zeros(shape, device=dev), torch.zeros(shape, device=dev)

    def forward(self, image_input, language_input, internal1=None, internal2=None):
        """Score a token sequence conditioned on a batch of images.

        Args:
            image_input: float tensor whose first dim is the batch; after `cnn`
                it must flatten to the 512 features `fc1` expects.
            language_input: long tensor of token ids, shape (bs, seq_len).
            internal1, internal2: optional (h, c) states for `lstm1` / `lstm2`,
                each (num_layers, bs, lstm_hidden_size). Zeros when None.

        Returns:
            (logits of shape (bs, seq_len, len_lexicon), (h1, c1), (h2, c2))
        """
        bs = image_input.shape[0]
        sl = language_input.shape[1]
        h1, c1 = internal1 if internal1 is not None else self._zero_state(self.lstm1, bs)
        h2, c2 = internal2 if internal2 is not None else self._zero_state(self.lstm2, bs)
        image_output = self.fc1(self.cnn(image_input).flatten(1))
        # Feed the same image feature at every time step: (bs, F) -> (bs, sl, F).
        image_output = image_output.unsqueeze(1).expand(bs, sl, self.fc1_output_size)
        language_output, (h1, c1) = self.lstm1(self.embed(language_input), (h1, c1))
        concatenated = torch.cat([image_output, language_output], 2)
        lstm2_out, (h2, c2) = self.lstm2(concatenated, (h2, c2))
        out = self.fc2(lstm2_out)
        return out, (h1, c1), (h2, c2)

    def _build_image_input(self, idx):
        """Stack the scaled image and its measure-data channel for image indices `idx`.

        Reads the global `images` and `measure_data_channels`; returns a float32
        tensor of shape (len(idx), 2, 200, 200) on the model's device.
        """
        image = images[idx].reshape(-1, 1, 200, 200) / 255
        channel = measure_data_channels[idx].reshape(-1, 1, 200, 200)
        arr = np.concatenate([image, channel], axis=1)
        return torch.as_tensor(arr, dtype=torch.float32, device=self._device())

    def _log_sample(self, iteration, loss_value, elapsed, log_path):
        """Greedy-decode one random training image; print and append a report to `log_path`."""
        n = np.random.randint(len(subsequences))
        image_index = image_indices[n]
        arr = self._build_image_input(image_indices[n:n + 1])
        pc = other_data['aux_data'][image_index]['pc']
        pc = ' '.join([ix_to_word[str(ix)] for ix in pc])
        pred_seq = ' '.join(self.predict(arr))
        # Built line by line so that whitespace inside the decoded text is kept.
        info_string = '\n'.join([
            '',
            '----',
            f'iteration: {iteration}',
            f'time elapsed: {elapsed}',
            f'loss: {loss_value}',
            '----',
            f'pred: {pred_seq}',
            '----',
            f'true: {pc}',
            '----',
            '', '', '', '',
        ])
        print(info_string)
        with open(log_path, 'a', encoding='utf-8') as f:
            f.write(info_string)

    def fit(self, total_iterations, optimizer, loss_fn, rate_decay, print_every=100,
            checkpoint_every=5000, log_path='./log_preprocessed_model-2019-09-23.txt'):
        """Train with teacher forcing on random minibatches of the global data.

        Each iteration samples `batch_size` subsequences (with replacement) and
        minimises `loss_fn` between the logits for tokens[:, :-1] and the
        targets tokens[:, 1:].

        Args:
            total_iterations: number of optimizer steps in this call.
            optimizer: torch optimizer over this model's parameters.
            loss_fn: criterion taking (N, len_lexicon) logits and (N,) targets.
            rate_decay: every `checkpoint_every` iterations (excluding 0) each
                param group's lr is multiplied by this; 1 keeps it constant.
            print_every: decode, print and log a random sample every this many
                iterations (0 disables).
            checkpoint_every: interval for lr decay and `torch.save` checkpoints
                (0 disables both).
            log_path: text file the periodic samples are appended to.
        """
        dev = self._device()
        train_start_time = time.time()
        try:
            for i in range(total_iterations):
                batch_indices = np.random.choice(len(subsequences), size=batch_size)
                image_input = self._build_image_input(image_indices[batch_indices])
                tokens = torch.as_tensor(subsequences[batch_indices], dtype=torch.long, device=dev)
                seq_in, seq_target = tokens[:, :-1], tokens[:, 1:]
                # Clear before backward so stale grads (e.g. from an interrupted
                # earlier run) never leak into this step.
                optimizer.zero_grad()
                out, _, _ = self.forward(image_input, seq_in)
                # reshape, not view: the target slice is non-contiguous.
                loss = loss_fn(out.reshape(-1, self.len_lexicon), seq_target.reshape(-1))
                loss.backward()
                optimizer.step()
                self.num_iterations += 1
                if print_every and i % print_every == 0:
                    self._log_sample(i, loss.item(), time.time() - train_start_time, log_path)
                if checkpoint_every and i % checkpoint_every == 0 and i != 0:
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= rate_decay
                    torch.save(self, f'./preprocessed_model_checkpoint_iteration_{i}-2019-09-23.pt')
        finally:
            # Record time even when training is interrupted (e.g. KeyboardInterrupt).
            self.train_time += time.time() - train_start_time

    def predict(self, arr, max_len=400):
        """Greedy-decode a token sequence for a single image.

        Args:
            arr: tensor with 2*200*200 elements (reshaped to (1, 2, 200, 200)).
            max_len: maximum length of the returned list, including '<START>'.

        Returns:
            List of token strings beginning with '<START>' and ending with
            '<END>' unless `max_len` is reached first.
        """
        was_training = self.training
        self.eval()
        dev = self._device()
        try:
            with torch.no_grad():
                arr = arr.reshape(1, 2, 200, 200).to(dev)
                output_sequence = ['<START>']
                internal1 = self._zero_state(self.lstm1, 1)
                internal2 = self._zero_state(self.lstm2, 1)
                while output_sequence[-1] != '<END>' and len(output_sequence) < max_len:
                    token = torch.tensor([[word_to_ix[output_sequence[-1]]]],
                                         dtype=torch.long, device=dev)
                    out, internal1, internal2 = self.forward(arr, token, internal1, internal2)
                    next_ix = out[0, -1].argmax().item()
                    output_sequence.append(ix_to_word[str(next_ix)])
        finally:
            # Restore the caller's mode instead of always switching to train mode.
            self.train(was_training)
        return output_sequence

    def predict_from_image(self, path, measure_length, key_number):
        """Greedy-decode a token sequence from an image file (e.g. a PNG).

        RGB and RGBA images are converted to grayscale (alpha dropped); a
        gray+alpha image keeps only its gray channel. The image is resized to
        200x200 and stacked with the measure-data channel.

        NOTE(review): `get_measure_data_channel` is not defined in any visible
        cell; it must be defined or imported before this method is called.
        """
        from skimage.color import rgb2gray
        from skimage.transform import resize

        image = io.imread(path) / 255
        if image.ndim == 3:
            if image.shape[2] in (3, 4):
                image = rgb2gray(image[:, :, :3])
            elif image.shape[2] == 2:
                image = image[:, :, 0]
        image = resize(image, (200, 200), cval=1)
        measure_channel = get_measure_data_channel(measure_length, key_number, 200, 200)
        arr = np.array([image, measure_channel])
        arr = torch.as_tensor(arr, dtype=torch.float32, device=self._device())
        return self.predict(arr)

In [0]:
# Criterion over (N, len_lexicon) logits vs (N,) token-id targets.
loss_fn = nn.CrossEntropyLoss()
model = Net(len_lexicon, lstm_hidden_size, fc1_output_size, device)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [42]:
# Short smoke-test run: 100 steps, sample logged every 10; rate_decay=1 keeps lr fixed.
model.fit(total_iterations=100,
          optimizer=optimizer,
          loss_fn=loss_fn,
          rate_decay=1,
          print_every=10)


----
iteration: 0
time elapsed: 1.1943106651306152
loss: 3.4831182956695557
----
pred: <START> type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type type

KeyboardInterrupt: ignored