In [49]:
# %load '/Users/jesusnavarro/Desktop/gan_video/core/main.py'
from torch import optim
from torch.utils.data import DataLoader
import torch.autograd
from torchvision.transforms import transforms
from torchvision.transforms import ToPILImage
import logging
from core.model import *
from core.UAVDataset import *

def setup_logger(logger_name):
    # format string
    LOG_FORMAT = "%(levelname)s - %(asctime)s - %(messages)s"
    logger = logging.getLogger(logger_name)
    
    # create formatter
    formatter = logging.Formatter(fmt=LOG_FORMAT)
    
    # create handler, streamhandler, and format
    file_handler = logging.FileHandler('training_log.txt', mode='w')
    file_handler.setFormatter(file_handler)
    screen_handler = logging.StreamHandler(stream=sys.stdout)
    screen_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.addHandler(screen_handler)
    
    return logger

def save_image(img, filename):
    img = img.detach()
    pil_img = ToPILImage()
    img = pil_img(img)
    img.save(filename)
    

if __name__ == '__main__':

    num_epochs = 30
    num_batch = 5
    
    # Create binary cross entropy function
    loss = torch.nn.BCELoss()
    batch_size = 5
    
    # Define transforms to apply to the data
    composed = transforms.Compose([Rescale(64), ToTensor()])
    
    # Define location of data pickle files
    # Batches of 32 sequential frame paths are grouped together in a pandas df
    working_dir = %pwd
    sub_dir = '/pickle_data/test_200_videos.pickle'
    
    # Instantiate DataLoader object
    transformed_uav_dataset = UAVDataset(pickle_path= working_dir + sub_dir,
                                         transform=composed)
    data_loader = DataLoader(transformed_uav_dataset,
                             batch_size=batch_size,
                             shuffle=True)
    num_batches = len(data_loader)
    
    # Instantiate generator and discriminator
    generator = VideoGen().float()
    discriminator = VideoDiscriminator().float()
    
    # Directory to save generator videos
    DIR_TO_SAVE = "./gen_videos/"
    if not os.path.exists(DIR_TO_SAVE):
        os.makedirs(DIR_TO_SAVE)
    sample_input = None
    sample_input_set = False

    loss_func = nn.CrossEntropyLoss()
    
    # Define optimizer for the discriminator and generator
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.002)
    g_optimizer = optim.Adam(generator.parameters(), lr=0.002)

    for eopch in range(num_epochs):

        for (iter_num, batch) in enumerate(data_loader):
            # reset gradients
            generator.zero_grad()
            discriminator.zero_grad()
            
            real_video = batch['video'] 
            real_video = real_video.unsqueeze(0).repeat(1, 1, 1, 1, 1)
            real_video = real_video.permute(0,2,1,3,4)
            
            real_labels = torch.tensor(np.ones(1, real_video.size(0)))
            fake_labels = torch.tensor(np.zeros(1, real_video.size()))
            
            if not iter_num % 3 == 0: # Train discriminator
                r_real_output = discriminator(real_video).squeeze()
                
                # Generate a fake video, detach parameters
                latent_z = torch.randn(1, 1, 1, 1, 100)
                fake_video = generator(latent_z).squeeze().detach()
                d_fake_output = discriminator(fake_video)
                
                # Compute real and fake loss
                d_fake_loss = loss_function(d_fake_output, fake_labels)
                d_real_loss = loss_function(d_real_output, real_labels)
                d_loss = d_real_loss + d_fake_loss
                
                # Update Gradient
                d_loss.backward()
                d_optimizer.step()
                
            else: # Train generator
                first_frame = real_video[:,:,0:1,:,:]
                latent_z = torch.randn(1,1,1,N,100)
                fake_videos = generator(latent_z)
                d_fake_outputs = discriminator(fake_videos).squeeze()
                gen_first_frame = fake_videos[:,:,0:1,:,:]
                reg_loss = torch.mean(torch.abs(first_frame - gen_first_frame)) * l1_lambda
                g_loss = loss_function(outputs, real_labels) + reg_loss
                g_loss.backward()
                g_optimizer.step()
                
                
                




                                               video
0  [[/Users/jesusnavarro/Desktop/gan_video/UAV123...
1  [[/Users/jesusnavarro/Desktop/gan_video/UAV123...
2  [[/Users/jesusnavarro/Desktop/gan_video/UAV123...
3  [[/Users/jesusnavarro/Desktop/gan_video/UAV123...
4  [[/Users/jesusnavarro/Desktop/gan_video/UAV123...


RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

'/Users/jesusnavarro/Desktop/gan_video'