In [1]:
%tensorflow_version 2.x
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


In [3]:
import torch.nn as nn

In [33]:
class ResidualBlock(nn.Module):
  def __init__(self,in_channel,out_channel):
    super().__init__()
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.blocks = nn.Identity()
    self.activation = nn.ReLU()
    self.short_cut = nn.Identity()
  def forward(self,x):
    residual = x
    if self.in_channel != self.out_channel : self.short_cut(x)
    x = self.blocks(x)
    x+=residual
    x= self.activation(x)
    return x



In [35]:
print(ResidualBlock(3,64))

ResidualBlock(
  (blocks): Identity()
  (activation): ReLU()
  (short_cut): Identity()
)


In [39]:
class ResnetResidualBlock(ResidualBlock):
  def __init__(self,in_channel,out_channel,downsampling = 1):
    super().__init__(in_channel,out_channel)
    self.downsampling = downsampling
    self.short_cut = nn.Sequential(
        nn.Conv2d(self.in_channel,self.out_channel,1,self.downsampling),
        nn.BatchNorm2d(self.out_channel)
    )


In [40]:
print(ResnetResidualBlock(3,64))

ResnetResidualBlock(
  (blocks): Identity()
  (activation): ReLU()
  (short_cut): Sequential(
    (0): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


In [29]:
class ResnetBasicBlock(ResnetResidualBlock):
  def __init__(self,in_channel,out_channel,*args,**kwargs):
    super().__init__(in_channel,out_channel)
    self.blocks = nn.Sequential(
        nn.Conv2d(in_channel,out_channel,kernel_size=3,stride=self.downsampling),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel,out_channel,kernel_size=3),
        nn.BatchNorm2d(out_channel)
    )

In [30]:
print(ResnetBasicBlock(3,64,do))

ResnetBasicBlock(
  (blocks): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (activation): ReLU()
  (short_cut): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
