In [1]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
from data_loader_ddr_fixed_length import get_loader 
from build_vocab_ddr import Vocabulary
from model_fixed_length import EncoderCNN, DecoderRNN
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
import nltk

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# From Image Captioning Code

In [2]:
class Args:
    model_path = 'ddrmodels/'
    crop_size = 224
    vocab_path = 'data/vocab_ddr.pkl'
    image_dir = 'data/resizedddr'
    caption_path = 'data/annotations/spectrogram_2.csv'
    log_step = 10
    save_step = 100
    embed_size = 256
    hidden_size = 512
    num_layers = 3
    num_epochs = 1
    batch_size = 2
    num_workers = 2
    learning_rate = 0.001
    
args=Args()

In [3]:
def main(args):
    CAPTION_LENGTH = 1000
    
    
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    
    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([ 
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])
    
    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    
    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab, 
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers) 

    # Discriminator
    d_encoder = EncoderCNN(args.embed_size).to(device)
    d_decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)
    
    D = nn.Sequential(
        nn.Linear(CAPTION_LENGTH, 256),
        #nn.Linear(len(vocab), 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1),
        nn.Sigmoid())
       
    D = D.to(device)

    # Generator
    
    g_encoder = EncoderCNN(args.embed_size).to(device)
    g_decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)

    # Binary cross entropy loss and optimizer
    criterion = nn.BCELoss()
    d_params = list(D.parameters()) + list(d_encoder.parameters()) + list(d_decoder.parameters())
    d_optimizer = torch.optim.Adam(d_params, lr=0.0002)
    
    g_params = list(g_decoder.parameters()) + list(g_encoder.parameters())
    g_optimizer = torch.optim.Adam(g_params, lr=0.0002)

    def denorm(x):
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def reset_grad():
        d_optimizer.zero_grad()
        g_optimizer.zero_grad()

    # Start training
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):
            
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

            # Create the labels which are later used as input for the BCE loss
            real_labels = torch.ones(args.batch_size, 1).to(device)
            fake_labels = torch.zeros(args.batch_size, 1).to(device)

            # ================================================================== #
            #                      Train the discriminator                       #
            # ================================================================== #

            # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
            # Second term of the loss is always zero since real_labels == 1
            #outputs = D(images)
            #outputs = decoder(features, captions, lengths)
            print(lengths)
            print(images.shape)
            features = d_encoder(images)
            print(features.shape)
            outputs = d_decoder(features, captions, lengths)
            print(outputs.shape)
            print(captions.shape)
            
            d_outputs = D(outputs)
            print("d_outputs: ", d_outputs.shape, "Real Labels: ", real_labels.shape)
            d_loss_real = criterion(d_outputs, real_labels)
            real_score = outputs

            # Compute BCELoss using fake images
            # First term of the loss is always zero since fake_labels == 0
            z = torch.randn(args.batch_size, latent_size).to(device)
            fake_images = G(z)
            outputs = D(fake_images)
            d_loss_fake = criterion(outputs, fake_labels)
            fake_score = outputs

            # Backprop and optimize
            d_loss = d_loss_real + d_loss_fake
            reset_grad()
            d_loss.backward()
            d_optimizer.step()

            # ================================================================== #
            #                        Train the generator                         #
            # ================================================================== #

            # Compute loss with fake images
            z = torch.randn(args.batch_size, latent_size).to(device)
            fake_images = G(z)
            outputs = D(fake_images)

            # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
            # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
            g_loss = criterion(outputs, real_labels)

            # Backprop and optimize
            reset_grad()
            g_loss.backward()
            g_optimizer.step()

            if (i+1) % 200 == 0:
                print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                      .format(epoch, args.num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                              real_score.mean().item(), fake_score.mean().item()))

        # Save real images
        if (epoch+1) == 1:
            images = images.reshape(images.size(0), 1, 28, 28)
            save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

        # Save sampled images
        fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
        save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

    # Save the model checkpoints 
    torch.save(G.state_dict(), 'G.ckpt')
    torch.save(D.state_dict(), 'D.ckpt')

if __name__ == '__main__':
    main(args)

[1000, 1000]
torch.Size([2, 3, 224, 224])
torch.Size([2, 256])
torch.Size([2, 1000, 256])
torch.Size([2, 256])


RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 3 and 2 at c:\a\w\1\s\tmp_conda_3.7_061434\conda\conda-bld\pytorch_1544163540495\work\aten\src\thc\generic/THCTensorMath.cu:74