In [30]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from PIL import Image

In [29]:
def preprocess(PIL_Image, image_shape):
    """Resize a PIL image and convert it to a tensor.

    Args:
        PIL_Image: input ``PIL.Image.Image``.
        image_shape: target size forwarded to ``transforms.Resize``. An
            ``(h, w)`` pair resizes to exactly that size; a single int resizes
            the shorter side to that value and keeps the aspect ratio.

    Returns:
        Tensor of shape ``(C, H, W)`` produced by ``transforms.ToTensor``
        (8-bit images are scaled to ``[0, 1]``).
    """
    # Compose expects a *list* of transforms; passing them positionally raises.
    trans = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
    ])
    return trans(PIL_Image)

In [44]:
def vgg_block(in_channel, out_channel, i, kernel_size=3, stride=1, padding=1):
    """Build one VGG convolutional stage.

    Args:
        in_channel: number of channels entering the stage.
        out_channel: number of channels produced by every conv in the stage.
        i: 1-based stage index. Stages 1 and 2 get two convs; any other index
            gets three (the VGG-16 layout).
        kernel_size, stride, padding: forwarded to every ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` of ``(Conv2d, ReLU)`` pairs followed by a
        2x2 / stride-2 ``MaxPool2d`` that halves the spatial size.
    """
    num_convs = 2 if i in (1, 2) else 3
    layers = []
    channels = in_channel
    for _ in range(num_convs):
        layers.append(nn.Conv2d(channels, out_channel, kernel_size=kernel_size,
                                stride=stride, padding=padding))
        # Without a nonlinearity the stacked convs collapse to one linear map.
        layers.append(nn.ReLU(inplace=True))
        # Only the first conv sees in_channel; the rest chain out_channel -> out_channel.
        channels = out_channel
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
    return nn.Sequential(*layers)

In [45]:
# Backbone layout, one row per stage: (stage index, in_channels, out_channels).
# The stage index is what `vgg_block` uses to decide how many convs to stack;
# each row's in_channels equals the previous row's out_channels.
conv_arch = (
    (1, 3, 64),
    (2, 64, 128),
    (3, 128, 256),
    (4, 256, 512),
    (5, 512, 512),
)

In [46]:
class conv_layers(nn.Module):
    """VGG-16-style convolutional backbone built from ``vgg_block`` stages.

    Each entry of ``arch`` is ``(stage_index, in_channels, out_channels)``. The
    stages are chained in order inside one ``nn.Sequential`` stored as
    ``self.fc``. The attribute keeps its original name even though it holds
    conv layers, not a fully-connected layer.
    """

    def __init__(self, arch=None):
        """Build the backbone.

        Args:
            arch: optional stage table; defaults to the module-level
                ``conv_arch`` so ``conv_layers()`` keeps working.
        """
        super(conv_layers, self).__init__()
        if arch is None:
            arch = conv_arch
        net = nn.Sequential()
        for stage_idx, in_channels, out_channels in arch:
            # in_channels is used verbatim: the table already makes it equal to
            # the previous stage's out_channels, which is what the conv expects.
            net.add_module("vgg_block_" + str(stage_idx),
                           vgg_block(in_channels, out_channels, stage_idx))
        self.fc = net

    def forward(self, x):
        """Run ``x`` through every stage in order and return the feature map."""
        return self.fc(x)

In [47]:
PRN_in_channels, PRN_out_channels = 512, 18


class PRN(nn.Module):
    """Per-anchor two-way softmax scoring head over a feature map.

    Pipeline: 3x3 conv + ReLU, then 1x1 conv to ``out_channels`` scores,
    reshape ``(N, C, H, W) -> (N, 2, (C // 2) * H, W)``, softmax over dim 1,
    and reshape back to ``(N, C, H, W)``. As a result, channel ``a`` and
    channel ``a + C // 2`` form one softmax pair and sum to 1.

    NOTE(review): this mirrors the Faster R-CNN RPN objectness branch
    (18 = 2 scores x 9 anchors, ``rpn_cls_score_reshape`` in the Caffe
    prototxt), which the original ``reshape()`` placeholders appear to
    target. Confirm this is the intended layout.
    """

    def __init__(self, in_channels=PRN_in_channels, out_channels=PRN_out_channels):
        """
        Args:
            in_channels: channels of the incoming feature map.
            out_channels: ``2 * num_anchors`` score channels; must be even.

        Raises:
            ValueError: if ``out_channels`` is odd.
        """
        super(PRN, self).__init__()
        if out_channels % 2:
            raise ValueError("out_channels must be 2 * num_anchors, got %d" % out_channels)
        self.num_anchors = out_channels // 2
        self.fc = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv: padding must be 0, otherwise the score map grows by 2 in H and W.
        self.prototxt = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, X):
        """Return per-anchor class probabilities shaped like the raw scores.

        Args:
            X: 4-D feature map ``(N, in_channels, H, W)``.
        """
        scores = self.prototxt(self.relu(self.fc(X)))
        n, c, h, w = scores.shape
        # Group channels into 2 classes x num_anchors so dim 1 is the class axis.
        probs = self.softmax(scores.reshape(n, 2, self.num_anchors * h, w))
        return probs.reshape(n, c, h, w)

In [48]:
PRN_bouding_box_out_channels, PRN_bouding_box_in_channels = 36, 512


class PRN_bouding_box(nn.Module):
    """Box-regression head: 3x3 conv + ReLU followed by a 1x1 conv.

    The output has the same spatial size as the input and ``out_channels``
    channels. NOTE(review): 36 presumably means 4 box deltas x 9 anchors
    (the RPN ``rpn_bbox_pred`` layout); confirm.
    """

    def __init__(self, in_channels=PRN_bouding_box_in_channels,
                 out_channels=PRN_bouding_box_out_channels):
        """
        Args:
            in_channels: channels of the incoming feature map.
            out_channels: number of regression outputs per location.
        """
        super(PRN_bouding_box, self).__init__()
        self.fc = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, stride=1)
        # Nonlinearity so the 3x3 and 1x1 convs don't collapse into one linear conv.
        self.relu = nn.ReLU(inplace=True)
        # 1x1 conv: padding must be 0, otherwise the output grows by 2 in H and W.
        self.box = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=1)

    def forward(self, X):
        """Map a ``(N, in_channels, H, W)`` feature map to ``(N, out_channels, H, W)``."""
        return self.box(self.relu(self.fc(X)))