In [None]:
import numpy as np
import random

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
from torch.distributions import Beta
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from torch.autograd import Function

import gnwrapper
import gym
from gym import logger as gymlogger
gymlogger.set_level(30)

from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display
import wandb

%matplotlib inline
import matplotlib.pyplot as plt

import car_racing as cr

import os
import glob

# Select the compute device: CUDA when torch reports it as available, else CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

print(device)

# Make sure the output directory exists; exist_ok avoids the separate
# check-then-create step.
os.makedirs("output_r", exist_ok=True)

# Remove PNG files left in the output directory. A single pass is enough:
# once the files are removed, further passes would find nothing to delete.
for png_path in glob.glob("./output_r/*.png"):
    os.remove(png_path)

In [None]:
class ReverseLayerF(Function):
    """Autograd function that is the identity on the forward pass and
    negates and scales the gradient by ``alpha`` on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        """Return a view of ``x`` unchanged and remember ``alpha`` for backward."""
        # Stash the scaling factor on the context so backward can read it.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and no gradient for ``alpha``."""
        flipped = torch.neg(grad_output)
        scaled = flipped * ctx.alpha
        # Second slot corresponds to ``alpha``, which receives no gradient.
        return scaled, None

class DSN(nn.Module):
    """Holds a "sketch" layer and three convolutional encoders.

    Sub-modules created here:
        sketch: ``Conv3d(4, 4, kernel_size=(1, 1, 3), stride=1)`` followed by ReLU.
        source_encoder: four stride-2 ``Conv2d`` + ReLU stages, 4 -> 8 -> 16 -> 32 -> 64 channels.
        target_encoder: same layout as ``source_encoder`` with its own weights.
        shared_encoder: same layout, but the last stage outputs 48 channels.

    No ``forward`` is defined in this class body; the sub-modules are meant to
    be called individually.

    Args:
        num_out: Kept for interface compatibility; none of the layers built
            here use it.
    """

    def __init__(self, num_out = 2):
        super(DSN, self).__init__()

        self.sketch = nn.Sequential(
            nn.Conv3d(4, 4, kernel_size=(1, 1, 3), stride=1),
            nn.ReLU(),
        )

        # The three encoders share one layout and differ only in the width of
        # the final convolution. They are built in the original order so that
        # parameter initialisation under a fixed seed is unchanged.
        self.source_encoder = self._make_encoder(64)
        self.target_encoder = self._make_encoder(64)
        self.shared_encoder = self._make_encoder(48)

    @staticmethod
    def _make_encoder(out_channels):
        """Build one 4-stage stride-2 Conv2d/ReLU stack ending in ``out_channels`` channels."""
        return nn.Sequential(
            nn.Conv2d(4, 8, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, out_channels, kernel_size=3, stride=2),
            nn.ReLU(),
        )

In [None]:
# Smoke test: push a random batch through each DSN sub-module and print the
# resulting tensor sizes.
test_tensor = torch.rand(64, 4, 96, 96, 3)

dsn = DSN()

# The sketch Conv3d uses a kernel of length 3 on the last axis, so the trailing
# size-3 axis comes out with size 1. Squeeze only that axis: a bare
# torch.squeeze() would also drop the batch axis whenever the batch size is 1.
sketch = dsn.sketch(test_tensor)
sketch = sketch.squeeze(-1)
print(sketch.size())

source_encode = dsn.source_encoder(sketch)
print(source_encode.size())

target_encode = dsn.target_encoder(sketch)
print(target_encode.size())

share_encode = dsn.shared_encoder(sketch)
print(share_encode.size())
