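- Based on: F. Jiang et al., "An End-to-End Compression Framework Based on Convolutional Neural Networks" — https://arxiv.org/pdf/1708.00838v1.pdf
- Reference implementation: https://github.com/kunalrdeshmukh/End-to-end-compression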

In [68]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

### Data

In [25]:
def txt_to_matrix(filename, line_skip = 5):
    """Load a whitespace-separated numeric text file as a float32 array.

    Args:
        filename: Path of the text file to read.
        line_skip: Number of leading (header) lines to discard. Defaults to 5.

    Returns:
        ``np.float32`` array with one row per remaining line (2-D when every
        line has the same number of fields). Entries greater than 100 are
        replaced by 0.
    """
    # Context manager guarantees the handle is closed, even if parsing fails
    # (the previous version opened the file and never closed it).
    with open(filename, 'r') as f:
        lines = f.readlines()[line_skip:]

    # str.split() with no argument already discards the trailing newline.
    data = np.asarray([line.split() for line in lines]).astype(np.float32)

    # Values above 100 are zeroed out.
    # NOTE(review): the meaning of the 100 threshold (sentinel / no-data value?)
    # is not visible here — confirm against the data format.
    data[data > 100] = 0

    return data

In [26]:
def get_dep_time_step(root, index):
    """Read the ``.DEP`` matrix of a single time step.

    Args:
        root: Directory prefix, concatenated directly with the file name
            (so it is expected to end with a path separator).
        index: Time-step identifier string, e.g. ``"0007"``.

    Returns:
        NumPy array holding the matrix loaded by ``txt_to_matrix`` with an
        extra leading axis of length 1.
    """
    dep_path = "".join((root, 'mini-decoded-', index, '.DEP'))
    matrix = txt_to_matrix(dep_path)
    # Wrap in a list so the result gains a leading singleton (channel) axis.
    return np.array([matrix])

In [27]:
# Load every time step of the dataset and stack them into one float32 array.
rootdir = '../datasets/530-9m1s/'
timesteps = []
ignore = [".DS_Store", ".", ".."]  # NOTE(review): not used anywhere visible

offset = 0
ceiling = 530

# Time-step files are numbered with four zero-padded digits, from `offset`
# (inclusive) up to `ceiling` (exclusive). The loop variable is the file
# number itself, so no separate manual counter is needed.
for step in tqdm.tqdm(range(offset, ceiling)):
    timesteps.append(get_dep_time_step(rootdir, "{:04d}".format(step)))

timesteps = np.asarray(timesteps).astype(np.float32)

100%|██████████| 530/530 [00:51<00:00, 10.33it/s]


### Network

In [29]:
timesteps.shape  # inspect the stacked dataset dimensions (time steps, 1, rows, cols)

(530, 1, 336, 341)

In [103]:
# Model / training hyper-parameters.
CHANNELS = 1  # number of image channels fed to the network

# The decoder resizes to a square HEIGHT x HEIGHT image (a single int size is
# given to the interpolation), so the training inputs are cropped to the same
# square: WIDTH reuses the row count (shape[2]) rather than the column count
# (shape[3]). Using shape[3] would make the reconstruction and the input
# differ in width.
HEIGHT = WIDTH = timesteps.shape[2]

EPOCHS = 200
LOG_INTERVAL = 10
BATCH_SIZE = 16

In [90]:
class Interpolate(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    Lets a fixed resize be used as a layer inside an ``nn.Module``.

    Args:
        size: Target spatial size forwarded to ``interpolate`` (an int yields
            a square output for 4-D image batches).
        mode: Interpolation mode, e.g. ``'bilinear'`` or ``'nearest'``.
        align_corners: Forwarded to ``interpolate``. When left as ``None`` it
            becomes ``False`` for the modes that accept it
            (linear/bilinear/bicubic/trilinear), matching the previous fixed
            behaviour, and stays ``None`` otherwise because PyTorch rejects
            an explicit value for modes such as ``'nearest'`` or ``'area'``.
    """

    # Modes for which nn.functional.interpolate accepts a non-None align_corners.
    _ALIGNABLE_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, size, mode, align_corners=None):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode
        if align_corners is None and mode in self._ALIGNABLE_MODES:
            align_corners = False
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` to ``self.size`` using ``self.mode``."""
        return self.interp(x, size=self.size, mode=self.mode,
                           align_corners=self.align_corners)

In [91]:
class End_to_end(nn.Module):
  """Convolutional compression autoencoder with a residual decoder.

  ``encode`` turns an input batch into a compact ``CHANNELS``-channel
  representation at roughly half the spatial resolution (one stride-2,
  unpadded 3x3 convolution). ``decode`` resizes that representation back to
  ``HEIGHT x HEIGHT`` by bilinear interpolation and adds a residual predicted
  by a stack of 3x3 convolutions.

  NOTE(review): layer shapes are taken from the module-level constants
  CHANNELS and HEIGHT at construction time.
  """
  def __init__(self):
    super(End_to_end, self).__init__()
    
    # Encoder
    # conv1 keeps the spatial size (3x3, padding 1); conv2 downsamples
    # (stride 2, no padding); conv3 projects back to CHANNELS channels.
    self.conv1 = nn.Conv2d(CHANNELS, 64, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=0)
    self.bn1 = nn.BatchNorm2d(64, affine=False)
    self.conv3 = nn.Conv2d(64, CHANNELS, kernel_size=3, stride=1, padding=1)
    
    # Decoder
    # A single int size is passed, so the upsampled image is HEIGHT x HEIGHT.
    self.interpolate = Interpolate(size=HEIGHT, mode='bilinear')
    self.deconv1 = nn.Conv2d(CHANNELS, 64, 3, stride=1, padding=1)
    self.deconv2 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(64, affine=False)
    
    # NOTE(review): this one conv/BN pair is applied 5 times in `decode`, so
    # those 5 stages share weights and BatchNorm running statistics. Confirm
    # weight sharing is intended rather than 5 distinct layers.
    self.deconv_n = nn.Conv2d(64, 64, 3, stride=1, padding=1)
    self.bn_n = nn.BatchNorm2d(64, affine=False)

    
    # Stride 1 / padding 1 with a 3x3 kernel keeps the spatial size, so the
    # residual matches the upsampled image it is added to.
    self.deconv3 = nn.ConvTranspose2d(64, CHANNELS, 3, stride=1, padding=1)
    
    self.relu = nn.ReLU()
  
  def encode(self, x):
    """Compress ``x`` into the compact CHANNELS-channel representation."""
    out = self.relu(self.conv1(x))
    out = self.relu(self.conv2(out))
    out = self.bn1(out)
    return self.conv3(out)
    
  
  def reparameterize(self, mu, logvar):
    """Unused placeholder: does nothing and returns None.

    NOTE(review): not called by ``forward``; looks like a leftover from a
    VAE template and could probably be removed.
    """
    pass
  
  def decode(self, z):
    """Reconstruct an image from the compact representation ``z``.

    Returns:
      Tuple ``(final, out, upscaled_image)``: the reconstruction
      ``upscaled_image + out``, the predicted residual ``out``, and the
      bilinearly upsampled ``z``.
    """
    upscaled_image = self.interpolate(z)
    out = self.relu(self.deconv1(upscaled_image))
    out = self.relu(self.deconv2(out))
    out = self.bn2(out)
    # Same deconv_n / bn_n modules reused on every iteration (shared weights).
    for _ in range(5):
      out = self.relu(self.deconv_n(out))
      out = self.bn_n(out)
    out = self.deconv3(out)
    # Residual learning: the network predicts a correction to the upsample.
    final = upscaled_image + out
    return final,out,upscaled_image

    
  def forward(self, x):
    """Encode then decode ``x``.

    Returns:
      Tuple ``(final, residual, upscaled_image, compact, x)``; the input is
      passed through so the loss can use it directly.
    """
    com_img = self.encode(x)
    final,out,upscaled_image = self.decode(com_img)
    return final, out, upscaled_image, com_img, x

In [92]:
# Build the model, moving it to the GPU when one is available.
CUDA = torch.cuda.is_available()
model = End_to_end()
if CUDA:
    # nn.Module.cuda() moves the parameters in place and returns the module.
    model = model.cuda()

print(f"GPU available ? {CUDA}")

optimizer = optim.Adam(model.parameters(), lr=1e-3)

GPU available ? False


### Training

In [93]:
def loss_function(final_img,residual_img,upscaled_img,com_img,orig_img):
    """Summed squared-error training loss for the compression network.

    Args:
        final_img: Reconstructed batch.
        residual_img: Residual predicted by the decoder.
        upscaled_img: Interpolated (upsampled) compact representation.
        com_img: Compact representation; unused, kept for call compatibility.
        orig_img: Original input batch.

    Returns:
        Scalar tensor: sum over all elements of the reconstruction error plus
        the residual-prediction error.
    """
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False; the functional form avoids building a module per call.
    com_loss = nn.functional.mse_loss(orig_img, final_img, reduction='sum')
    rec_loss = nn.functional.mse_loss(
        residual_img, orig_img - upscaled_img, reduction='sum')

    # NOTE(review): End_to_end.decode returns final = upscaled + residual, in
    # which case both terms are mathematically identical and this is just
    # twice one squared error. Confirm whether a different second term
    # was intended.
    return com_loss + rec_loss

In [None]:
# Train for EPOCHS passes over the dataset in consecutive, disjoint mini-batches.
num_samples = timesteps.shape[0]

for epoch in range(EPOCHS):
    # No eval() is called anywhere, so switching to train mode once per epoch
    # is sufficient.
    model.train()

    # Step the start index by BATCH_SIZE (previously it advanced by 1, so the
    # batches overlapped and only the first few dozen samples were ever used).
    # The final batch may be smaller when num_samples % BATCH_SIZE != 0.
    for b, start in enumerate(range(0, num_samples, BATCH_SIZE)):
        # Crop columns to WIDTH so inputs match the decoder's square output.
        data = torch.tensor(timesteps[start:start + BATCH_SIZE, :, :, :WIDTH],
                            dtype=torch.float32)
        if CUDA:
            data = data.cuda()

        optimizer.zero_grad()
        final, residual_img, upscaled_image, com_img, orig_im = model(data)
        loss = loss_function(final, residual_img, upscaled_image, com_img, orig_im)
        loss.backward()
        optimizer.step()

        # The loss is summed over the batch; divide by this batch's size (not
        # the whole dataset size) to report a per-sample average.
        train_loss = loss.item() / data.shape[0]
        print('====> Epoch: {} Batch {} Average loss: {:.4f}'.format(
            epoch, b, train_loss))

        if b % LOG_INTERVAL == 0:
            # .cpu() is required before .numpy() when training on the GPU.
            plt.matshow(final.detach().cpu().numpy()[0, 0])
            plt.show()

torch.Size([16, 1, 336, 336])
