In [1]:
import torch
from torch import nn
import pytorch_lightning as pl

In [27]:
class DBlock(nn.Module):
    """Down-sampling stage: Conv1d -> LeakyReLU -> MaxPool1d(2).

    ``forward`` returns two tensors: the pooled activations, and an
    independent copy of the activations taken just before pooling
    (presumably used as the skip/context input of ``UBlock.add_ctx``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodules are registered in the original order (lr, conv1d, mp) so
        # that the module repr and child iteration order stay the same.
        self.lr = nn.LeakyReLU()
        # kernel_size=3 with padding=1 and stride=1 keeps the length unchanged.
        self.conv1d = nn.Conv1d(in_channels, out_channels, 3, stride=1, padding=1)
        # Stride defaults to kernel_size, so this halves the length (floor).
        self.mp = nn.MaxPool1d(kernel_size=2)

    def forward(self, x):
        """Apply conv + activation, then pool.

        Args:
            x: input in Conv1d layout (batch, in_channels, length).

        Returns:
            (pooled, skip): ``pooled`` has length ``length // 2``. ``skip`` is a
            clone of the pre-pool activations at the full ``length``.
        """
        features = self.lr(self.conv1d(x))
        skip = features.clone()
        return self.mp(features), skip

class UBlock(nn.Module):
    """Up-sampling stage: Upsample(x2) -> Conv1d -> shared (Conv1d k=2 + LeakyReLU) applied twice.

    ``add_ctx`` center-crops a skip/context tensor to ``x``'s spatial size and
    concatenates it onto ``x`` along the channel axis.

    NOTE(review): ``forward`` does not use ``ctx`` yet; ``add_ctx`` is never
    called from it. Merging ``ctx`` in would change the number of channels
    reaching ``self.blck`` (built for ``out_channels`` inputs), so the layer
    sizes must be decided before wiring it up. TODO.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=3, stride=1, padding=1: length-preserving channel projection.
        self.conv1d = nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)

        # Default mode is nearest-neighbour; doubles the length.
        self.up_sampling = nn.Upsample(scale_factor=2)

        # Unpadded kernel_size=2 conv: each application shortens the length by 1.
        self.blck = nn.Sequential(
            nn.Conv1d(in_channels=out_channels,
                      out_channels=out_channels,
                      kernel_size=2), nn.LeakyReLU())

    def add_ctx(self, x, ctx):
        """Center-crop ``ctx`` to ``x``'s spatial size and concatenate on channels.

        Args:
            x: tensor in (batch, channels, *spatial) layout, e.g. Conv1d's
                (batch, channels, length).
            ctx: skip tensor with the same rank and batch size as ``x`` and
                spatial sizes at least as large as ``x``'s.

        Returns:
            ``torch.cat([x, cropped_ctx], dim=1)``. Its channel count is
            ``x.shape[1] + ctx.shape[1]``; all other dims match ``x``.

        Raises:
            ValueError: if the ranks differ, or if ``ctx`` is smaller than ``x``
                along a spatial dim (a negative slice offset would otherwise
                wrap around silently).
        """
        if ctx.dim() != x.dim():
            raise ValueError(
                f"rank mismatch: x has {x.dim()} dims, ctx has {ctx.dim()}")
        # Only spatial dims (index >= 2) are cropped. Batch and channel dims are
        # kept whole so that no skip channels are dropped.
        index = [slice(None), slice(None)]
        for dim in range(2, x.dim()):
            extra = ctx.shape[dim] - x.shape[dim]
            if extra < 0:
                raise ValueError(
                    f"ctx is smaller than x along dim {dim}: "
                    f"{ctx.shape[dim]} < {x.shape[dim]}")
            start = extra // 2
            index.append(slice(start, start + x.shape[dim]))
        cropped = ctx[tuple(index)]
        # torch.cat takes no ``keepdims`` argument; passing one raises TypeError.
        return torch.cat([x, cropped], dim=1)

    def forward(self, x, ctx):
        """Upsample and refine ``x``.

        Args:
            x: input in (batch, in_channels, length) layout.
            ctx: currently ignored (see the class NOTE).

        Returns:
            Tensor of shape (batch, out_channels, 2 * length - 2): the upsample
            doubles the length and each ``blck`` pass removes one step.
        """
        x = self.up_sampling(x)
        x = self.conv1d(x)
        # NOTE(review): the same ``blck`` module (shared weights) is applied
        # twice. If two independent layers were intended, a second module is
        # needed; confirm before training.
        x = self.blck(x)
        x = self.blck(x)
        return x




In [28]:
# Smoke test: push a tiny random batch through one DBlock, then merge the
# pooled output with its skip tensor via UBlock.add_ctx.
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All

t = torch.randn(2, 4, 2)  # Conv1d layout: (batch, channels, length)
dblock = DBlock(in_channels=4, out_channels=8)
ublock = UBlock(in_channels=8, out_channels=4)

out, ctx = dblock(t)
print(out.shape)
print(ctx.shape)
assert out.shape == (2, 8, 1)  # MaxPool1d(2) halves the length
assert ctx.shape == (2, 8, 2)  # skip keeps the pre-pool length

merged = ublock.add_ctx(out, ctx)
print(merged.shape)
assert merged.shape == (2, 16, 1)  # 8 + 8 channels, ctx cropped to length 1

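AttributeError: module 'torch' has no attribute 'copy'   (stale output from an earlier revision; the current code calls torch.clone, not torch.copy. Re-run this cell to refresh.)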

In [8]:
# Scratch check, presumably exploring cropping offsets: on an axis of size 2,
# start index -1 resolves to 1, so a[-1:2] selects the same slice as a[1:2].
# A negative start therefore wraps around rather than raising.
torch.manual_seed(0)  # reproducible values on re-run
a = torch.randn(2, 3, 1)
print(a)
print(a[-1:2, ...])
assert torch.equal(a[-1:2, ...], a[1:2, ...])

tensor([[[ 0.7045],
         [ 1.5843],
         [-0.7899]],

        [[ 0.3869],
         [-1.7405],
         [-0.3557]]])
tensor([[[ 0.3869],
         [-1.7405],
         [-0.3557]]])
