In [1]:
import sys
sys.path.insert(0,'/home/mohit.kumargupta/deep_boltzmann')
import numpy as np
import mdtraj as md
import torch
import torch.nn as nn
from pathlib import Path
import torchvision.transforms as transforms
import torch.optim as optim
from math import *
from random import gauss,seed
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt
import seaborn

In [2]:
class DensityEstimator(nn.Module):
    def __init__(self,dimer_atoms,output_dim):
        super(DensityEstimator, self).__init__()
        self.dimer_atoms = dimer_atoms
        self.hidden_dim = 256
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.fc1 = nn.Linear(self.dimer_atoms, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, int(self.hidden_dim/2) )
        self.fc3 = nn.Linear(int(self.hidden_dim/2), int(self.hidden_dim/4))

    def forward(self,x):
        out = self.relu( self.fc1(x) )
        out = self.relu(self.fc2(out))
        out = self.fc3(out)
        out = 1160*torch.softmax(out,1)
        return out


In [3]:
class PixShuf(nn.Module):
    def __init__(self, up):
        super(PixShuf,self).__init__()
        self.upsample = up
    def forward(self,inp):
        out = inp.reshape((inp.shape[0],inp.shape[1]//(self.upsample**3), self.upsample,self.upsample, self.upsample, inp.shape[2],inp.shape[3],inp.shape[4]))
        out = out.permute(0,1,5,2,6,3,7,4)
        out = out.reshape((out.shape[0],out.shape[1],inp.shape[2]*self.upsample,inp.shape[3]*self.upsample,inp.shape[4]*self.upsample))
        return out

In [4]:
class RESNET(nn.Module):
    def __init__(self,numberOfLayers= 18, dim = 4):
        super(RESNET, self).__init__()
        self.conv_input = nn.Conv3d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.residual = self.make_layer(Residual_Block, numberOfLayers)
        self.upscale2x = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=256*2, kernel_size=3, stride=1, padding=1, bias=False),
            PixShuf(2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_output = nn.Conv3d(in_channels=64, out_channels=1, kernel_size=5, stride=1, padding=2, bias= False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

    def make_layer(self, block, num_of_layer):
        layers = []
        for _ in range(num_of_layer):
            layers.append(block())
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.conv_input(x))
        residual = out
        out = self.residual(out)
        out = torch.add(out,residual)
        #self.relu(F.pixel_shuufle)
        #out = self.relu( F.pixel_shuffle( self.UpScaleConv(out), 2) )
        out = self.upscale2x(out)
        out = self.conv_output(out)
        return out

In [5]:
class Residual_Block(nn.Module):
    def __init__(self):
        super(Residual_Block, self).__init__()

        self.conv1 = nn.Conv3d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in1 = nn.InstanceNorm3d(64, affine=True)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv3d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in2 = nn.InstanceNorm3d(64, affine=True)

    def forward(self, x):
        identity_data = x
        output = self.in1( self.relu(self.conv1(x)) )
        output = self.in2( self.conv2(output) )
        output = torch.add(output,identity_data)
        return output

In [6]:
class CombineModel(nn.Module):
    def __init__(self,model1,model2,model3):
        super(CombineModel, self).__init__()
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
    
    def forward(self,inp):
        out = self.model1(inp)
        out = torch.reshape(out,(inp.shape[0],1,4,4,4) )
        print(out.shape)
        out = self.model2(out)
        out = self.model3(out)
        print(out.shape)
        return out

In [7]:
class ConfigurationModel(nn.Module):
    
    def __init__(self,model1):
        super(ConfigurationModel,self).__init__()
        self.model1 = model1
        self.hidden_dim = 16*16
        self.out_dim = 11
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.fc1 = nn.Linear(256*16, 100)
        self.fc2 = nn.Linear(100, 1164*3)
    
    def forward(self,inp):
        out = self.model1(inp)
        out = out.reshape((inp.shape[0],256*16))
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        print(out.shape)
        return out

In [8]:
sz =16

In [9]:
boxes =[]
x=0
while(x<3.3):
    y=0.0
    while(y<3.3):
        z=0.0
        while(z<3.31):
            boxes.append([x+ 3.4/(2*sz),y+3.4/(2*sz),z+3.4/(2*sz)])
            z+=3.4/sz
        y+= 3.4/sz
    x+= 3.4/sz

In [10]:
boxes = np.array(boxes)

In [11]:
vol = (3.4/sz) **3
box_size= 3.4

In [12]:
m1= DensityEstimator(42*3,64).to('cuda')
m2 = RESNET(9,dim=4).to('cuda')
m3 = RESNET(9,dim=8).to('cuda')

In [13]:
comModel = CombineModel(m1,m2,m3).to('cuda')

In [14]:
configModel = ConfigurationModel(comModel).to('cuda')

In [15]:
perm = np.arange(160)
validPerm = np.arange(20)+160

In [16]:
adOptimizer = optim.Adam(configModel.parameters(),lr =0.0001)

In [17]:
def findNeighbourList():
    global boxes
    bx = []
    for p,eachBox in enumerate(boxes):
        box = []
        for i in range(5) :
            for j in range(5):
                for k in range(5):
                    x= (int(p/256) + i-2)%16
                    y=  ( int( (p%256)/16 ) +j-2)%16
                    z = ( int(p%16) + k-2 ) %16
                    while(x<0):
                        x+=16
                    while(y<0):
                        y+=16
                    while(z<0):
                        z+=16
                    box.append(np.array(boxes[x*256 +y*16 +z] ) )
        bx.append(np.array(box))
    return bx

In [18]:
NeighbourList = np.array(findNeighbourList())
NeighbourList = torch.from_numpy(NeighbourList).type('torch.FloatTensor').to('cuda')

In [None]:
def otherDMapCalc(temp):
    global NeighbourList
    DensityMap = None
    box_size =3.4
    # FUnction to Look at and NeighbourList Function is implemented above
    # temp is used for snapshots
    %%time
    for i,snapshot in enumerate(temp):
        print(i)
        snapshot= torch.from_numpy(snapshot).to('cuda')
        myBox = (torch.floor((snapshot+box_size)*sz/box_size)%16 ).type('torch.IntTensor')
        myBox = myBox[:,0]*256 + myBox[:,1]*16 + myBox[:,2] #Finding Box Number

        MyList = []
        for box in myBox:
            MyList.append(NeighbourList[box].detach().to('cpu').numpy())

        MyList = torch.from_numpy(np.array(MyList)).type('torch.FloatTensor')
        MyList = MyList.to('cuda')

        box_size =3.4
        dx= torch.abs( (snapshot.unsqueeze(1).expand(-1,125,-1) -MyList) )
        dist = torch.norm( dx - box_size*torch.floor(0.5+dx/box_size),dim=2) #calculating distance using minimum image
        prob = torch.exp(-0.5 * torch.mul(dist/0.2,dist/0.2))/((2*3.14*0.04)**1.5 )

        MyList = torch.floor( (MyList*sz+ 0.01)/box_size )
        MyList = (MyList[:,:,0]*256 + MyList[:,:,1]*16  + MyList[:,:,2]).type('torch.IntTensor')
        print(MyList.shape,MyList[0,0])

        dMap = torch.ones(4096) - 1.0
        break
        for j in range(4096):
                idx = (MyList==j).nonzero()

        '''
        newList = []
        for j in range(MyList.shape[0]):
            for k in range(125):
                newList.append(j)
        newList = torch.from_numpy(np.array(newList))
        print(MyList[1,0],MyList.flatten()[125])
        out = torch.stack((newList.type('torch.IntTensor'),MyList.flatten()))
        dMap = torch.sparse.FloatTensor(out.type('torch.cuda.LongTensor'),prob.flatten())
        dMap = torch.sparse.sum(dMap,0)*vol 
        '''
        '''
        break
        newList = None
        for (j,atom) in enumerate(MyList):
            if(newList is None):
                newList = torch.stack( (torch.ones(125,dtype = torch.int).new_full((1,125),j).squeeze(0), MyList[j]) ).permute(1,0)
            else:
                newList= torch.cat( (torch.stack( (torch.ones(1,dtype = torch.int).new_full((1,125),j).squeeze(0) , MyList[j]),0 ).permute(1,0), newList ) )
        #newList = torch.from_numpy(np.array(newList)).to('cuda').type('torch.cuda.LongTensor')
        print(newList.t().shape)
        d = torch.sparse.FloatTensor(newList.t().type('torch.cuda.LongTensor'),prob.flatten())
        #print(torch.sparse.sum(d,0)*vol)
        break
        '''

<font color='red'>To Look at Lost Function </font>

In [28]:
def DensityMapCalc(traj):
    global NeighbourList
    DensityMap = None
    box_size =3.4
    for i,snapshot in enumerate(traj):
        print(i)
        snapshot= snapshot.to('cuda')
        myBox = (torch.floor((snapshot+box_size)*sz/box_size)%16 ).type('torch.IntTensor')
        myBox = myBox[:,0]*256 + myBox[:,1]*16 + myBox[:,2]

        MyList = []
        for box in myBox:
            MyList.append(NeighbourList[box].detach().to('cpu').numpy())

        MyList = torch.from_numpy(np.array(MyList)).type('torch.FloatTensor')
        MyList = MyList.to('cuda')
        
        dx= torch.abs( (snapshot.unsqueeze(1).expand(-1,125,-1) -MyList) )
        dist = torch.norm( dx - box_size*torch.floor(0.5+dx/box_size),dim=2)
        prob = torch.exp(-0.5 * torch.mul(dist/0.2,dist/0.2))/((2*3.14*0.04)**1.5 )

        MyList = torch.floor( (MyList*sz+ 0.01)/box_size )
        MyList = (MyList[:,:,0]*256 + MyList[:,:,1]*16  + MyList[:,:,2]).type('torch.IntTensor')
        print(MyList.max(),MyList.shape)
        break
        newList = []
        for j in range(1064):
            idx = (MyList==j).nonzero()
            
        for j in range(MyList.shape[0]):
            for k in range(125):
                newList.append(j)
        newList = torch.from_numpy(np.array(newList))
        #print(MyList[1,0],MyList.flatten()[125])
        out = torch.stack((newList.type('torch.IntTensor'),MyList.flatten()))
        dMap = torch.sparse.FloatTensor(out.type('torch.cuda.LongTensor'),prob.flatten()).to_dense()
        dMap = dMap.sum(0)*vol
        print(dMap.shape,dMap)
        if(DensityMap is None):
            DensityMap = dMap.unsqueeze(0)
        else:
            DensityMap = torch.cat((DensityMap,dMap.unsqueeze(0)))
    return DensityMap

In [20]:
def DensitylossFunction(out,sample):
    Density = DensityMapCalc(out.reshape((out.shape[0],1164,3)))
    print(sample.shape,Density.shape)
    return nn.MSELoss()(Density,sample)

In [21]:
TrajDirectory = "/home/mohit.kumargupta/confALA4/"
TrajFile = TrajDirectory +  "traj_comp.xtc"
TopFile = TrajDirectory + "ala4_amber_atA.gro"

In [22]:
xtObject = md.load(TrajFile, top = TopFile)

In [23]:
Epochs = 100

In [24]:
for epoch in range(Epochs):
    print(epoch)
    np.random.shuffle(perm)
    i = 0
    totalLoss = 0.0
    for eachVal in perm:
        if(eachVal == 31):
            continue
        print(i,eachVal)
        adOptimizer.zero_grad()
        sample = torch.from_numpy(np.load('ProteinData/ProteinData' + str(eachVal)+ '.npy')[0] ).type('torch.FloatTensor').to('cuda')
        mdObj = torch.from_numpy(xtObject.xyz[eachVal*500:(eachVal+1)*500,:42].reshape(500,42*3)).type('torch.FloatTensor')
        print(sample.sum()/500)
        out = configModel(mdObj.to('cuda'))
        print(out.shape)
        loss = DensitylossFunction(out,sample)
        print(loss)
        #loss = (torch.abs(out-sample)/sample).mean()
        
        mdObj = None
        sample = None
        out = None
        print(loss.item())
        totalLoss += loss.item()
        i+=1
        loss.backward()
        adOptimizer.step( )
    print(totalLoss)

0
0 25
tensor(1159., device='cuda:0')
torch.Size([500, 1, 4, 4, 4])
torch.Size([500, 1, 16, 16, 16])
torch.Size([500, 3492])
torch.Size([500, 3492])
0
torch.Size([4096]) tensor([53.4556, 27.1016,  6.2572,  ...,  0.9862, 10.4174, 32.8487],
       device='cuda:0', grad_fn=<MulBackward0>)
1
torch.Size([4096]) tensor([53.4477, 27.1020,  6.2585,  ...,  0.9881, 10.4163, 32.8438],
       device='cuda:0', grad_fn=<MulBackward0>)
2
torch.Size([4096]) tensor([53.4429, 27.1019,  6.2620,  ...,  0.9864, 10.4165, 32.8426],
       device='cuda:0', grad_fn=<MulBackward0>)
3
torch.Size([4096]) tensor([53.4430, 27.1001,  6.2610,  ...,  0.9863, 10.4166, 32.8439],
       device='cuda:0', grad_fn=<MulBackward0>)
4
torch.Size([4096]) tensor([53.4384, 27.1001,  6.2622,  ...,  0.9864, 10.4163, 32.8416],
       device='cuda:0', grad_fn=<MulBackward0>)
5
torch.Size([4096]) tensor([53.4391, 27.1015,  6.2625,  ...,  0.9864, 10.4151, 32.8396],
       device='cuda:0', grad_fn=<MulBackward0>)
6
torch.Size([4096]) te

torch.Size([4096]) tensor([53.4025, 27.0824,  6.2609,  ...,  0.9879, 10.4151, 32.8229],
       device='cuda:0', grad_fn=<MulBackward0>)
61
torch.Size([4096]) tensor([53.3809, 27.0791,  6.2661,  ...,  0.9881, 10.4124, 32.8095],
       device='cuda:0', grad_fn=<MulBackward0>)
62
torch.Size([4096]) tensor([53.3666, 27.0783,  6.2732,  ...,  0.9846, 10.4104, 32.8001],
       device='cuda:0', grad_fn=<MulBackward0>)
63
torch.Size([4096]) tensor([53.3722, 27.0769,  6.2695,  ...,  0.9864, 10.4110, 32.8038],
       device='cuda:0', grad_fn=<MulBackward0>)
64
torch.Size([4096]) tensor([53.3749, 27.0766,  6.2686,  ...,  0.9864, 10.4118, 32.8060],
       device='cuda:0', grad_fn=<MulBackward0>)
65
torch.Size([4096]) tensor([53.3729, 27.0818,  6.2741,  ...,  0.9856, 10.4109, 32.8027],
       device='cuda:0', grad_fn=<MulBackward0>)
66
torch.Size([4096]) tensor([53.3690, 27.0774,  6.2710,  ...,  0.9864, 10.4112, 32.8030],
       device='cuda:0', grad_fn=<MulBackward0>)
67
torch.Size([4096]) tensor([

torch.Size([4096]) tensor([53.3570, 27.0658,  6.2607,  ...,  0.9909, 10.4127, 32.8035],
       device='cuda:0', grad_fn=<MulBackward0>)
121
torch.Size([4096]) tensor([53.3565, 27.0656,  6.2636,  ...,  0.9873, 10.4124, 32.8019],
       device='cuda:0', grad_fn=<MulBackward0>)
122
torch.Size([4096]) tensor([53.3531, 27.0649,  6.2632,  ...,  0.9892, 10.4117, 32.7985],
       device='cuda:0', grad_fn=<MulBackward0>)
123
torch.Size([4096]) tensor([53.3617, 27.0633,  6.2595,  ...,  0.9891, 10.4130, 32.8048],
       device='cuda:0', grad_fn=<MulBackward0>)
124
torch.Size([4096]) tensor([53.3561, 27.0599,  6.2579,  ...,  0.9908, 10.4121, 32.8021],
       device='cuda:0', grad_fn=<MulBackward0>)
125
torch.Size([4096]) tensor([53.3564, 27.0578,  6.2566,  ...,  0.9906, 10.4117, 32.8027],
       device='cuda:0', grad_fn=<MulBackward0>)
126
torch.Size([4096]) tensor([53.3512, 27.0567,  6.2602,  ...,  0.9889, 10.4108, 32.7982],
       device='cuda:0', grad_fn=<MulBackward0>)
127
torch.Size([4096]) t

torch.Size([4096]) tensor([53.3928, 27.0596,  6.2497,  ...,  0.9857, 10.4135, 32.8227],
       device='cuda:0', grad_fn=<MulBackward0>)
181
torch.Size([4096]) tensor([53.3908, 27.0575,  6.2467,  ...,  0.9865, 10.4143, 32.8193],
       device='cuda:0', grad_fn=<MulBackward0>)
182
torch.Size([4096]) tensor([53.3949, 27.0597,  6.2473,  ...,  0.9865, 10.4119, 32.8190],
       device='cuda:0', grad_fn=<MulBackward0>)
183
torch.Size([4096]) tensor([53.3948, 27.0617,  6.2485,  ...,  0.9865, 10.4119, 32.8191],
       device='cuda:0', grad_fn=<MulBackward0>)
184
torch.Size([4096]) tensor([53.3829, 27.0631,  6.2535,  ...,  0.9850, 10.4075, 32.8091],
       device='cuda:0', grad_fn=<MulBackward0>)
185
torch.Size([4096]) tensor([53.3825, 27.0645,  6.2565,  ...,  0.9842, 10.4076, 32.8097],
       device='cuda:0', grad_fn=<MulBackward0>)
186
torch.Size([4096]) tensor([53.3945, 27.0687,  6.2565,  ...,  0.9841, 10.4092, 32.8169],
       device='cuda:0', grad_fn=<MulBackward0>)
187
torch.Size([4096]) t

torch.Size([4096]) tensor([53.3608, 27.0640,  6.2653,  ...,  0.9885, 10.4120, 32.8011],
       device='cuda:0', grad_fn=<MulBackward0>)
241
torch.Size([4096]) tensor([53.3656, 27.0624,  6.2618,  ...,  0.9889, 10.4173, 32.8078],
       device='cuda:0', grad_fn=<MulBackward0>)
242
torch.Size([4096]) tensor([53.3566, 27.0625,  6.2624,  ...,  0.9910, 10.4165, 32.8029],
       device='cuda:0', grad_fn=<MulBackward0>)
243
torch.Size([4096]) tensor([53.3454, 27.0625,  6.2662,  ...,  0.9896, 10.4147, 32.7950],
       device='cuda:0', grad_fn=<MulBackward0>)
244
torch.Size([4096]) tensor([53.3594, 27.0651,  6.2630,  ...,  0.9892, 10.4162, 32.8036],
       device='cuda:0', grad_fn=<MulBackward0>)
245
torch.Size([4096]) tensor([53.3582, 27.0642,  6.2608,  ...,  0.9914, 10.4172, 32.8036],
       device='cuda:0', grad_fn=<MulBackward0>)
246
torch.Size([4096]) tensor([53.3557, 27.0633,  6.2591,  ...,  0.9916, 10.4157, 32.8008],
       device='cuda:0', grad_fn=<MulBackward0>)
247
torch.Size([4096]) t

torch.Size([4096]) tensor([53.3956, 27.0586,  6.2500,  ...,  0.9883, 10.4193, 32.8297],
       device='cuda:0', grad_fn=<MulBackward0>)
301
torch.Size([4096]) tensor([53.3861, 27.0544,  6.2503,  ...,  0.9883, 10.4185, 32.8260],
       device='cuda:0', grad_fn=<MulBackward0>)
302
torch.Size([4096]) tensor([53.3933, 27.0539,  6.2472,  ...,  0.9884, 10.4210, 32.8345],
       device='cuda:0', grad_fn=<MulBackward0>)
303
torch.Size([4096]) tensor([53.3744, 27.0560,  6.2547,  ...,  0.9884, 10.4172, 32.8199],
       device='cuda:0', grad_fn=<MulBackward0>)
304
torch.Size([4096]) tensor([53.3802, 27.0546,  6.2519,  ...,  0.9883, 10.4183, 32.8237],
       device='cuda:0', grad_fn=<MulBackward0>)
305
torch.Size([4096]) tensor([53.3963, 27.0524,  6.2454,  ...,  0.9869, 10.4221, 32.8376],
       device='cuda:0', grad_fn=<MulBackward0>)
306
torch.Size([4096]) tensor([53.3863, 27.0539,  6.2487,  ...,  0.9883, 10.4201, 32.8303],
       device='cuda:0', grad_fn=<MulBackward0>)
307
torch.Size([4096]) t

torch.Size([4096]) tensor([53.3836, 27.0583,  6.2510,  ...,  0.9883, 10.4148, 32.8209],
       device='cuda:0', grad_fn=<MulBackward0>)
361
torch.Size([4096]) tensor([53.3952, 27.0610,  6.2488,  ...,  0.9884, 10.4162, 32.8278],
       device='cuda:0', grad_fn=<MulBackward0>)
362
torch.Size([4096]) tensor([53.3938, 27.0596,  6.2480,  ...,  0.9883, 10.4163, 32.8270],
       device='cuda:0', grad_fn=<MulBackward0>)
363
torch.Size([4096]) tensor([53.4058, 27.0606,  6.2464,  ...,  0.9869, 10.4175, 32.8337],
       device='cuda:0', grad_fn=<MulBackward0>)
364
torch.Size([4096]) tensor([53.4006, 27.0616,  6.2465,  ...,  0.9883, 10.4165, 32.8303],
       device='cuda:0', grad_fn=<MulBackward0>)
365
torch.Size([4096]) tensor([53.3929, 27.0618,  6.2490,  ...,  0.9884, 10.4152, 32.8250],
       device='cuda:0', grad_fn=<MulBackward0>)
366
torch.Size([4096]) tensor([53.3968, 27.0607,  6.2481,  ...,  0.9882, 10.4162, 32.8294],
       device='cuda:0', grad_fn=<MulBackward0>)
367
torch.Size([4096]) t

torch.Size([4096]) tensor([53.4381, 27.0760,  6.2403,  ...,  0.9852, 10.4176, 32.8457],
       device='cuda:0', grad_fn=<MulBackward0>)
421
torch.Size([4096]) tensor([53.4304, 27.0758,  6.2436,  ...,  0.9853, 10.4159, 32.8401],
       device='cuda:0', grad_fn=<MulBackward0>)
422
torch.Size([4096]) tensor([53.4447, 27.0795,  6.2423,  ...,  0.9852, 10.4180, 32.8495],
       device='cuda:0', grad_fn=<MulBackward0>)
423
torch.Size([4096]) tensor([53.4579, 27.0804,  6.2380,  ...,  0.9852, 10.4218, 32.8579],
       device='cuda:0', grad_fn=<MulBackward0>)
424
torch.Size([4096]) tensor([53.4403, 27.0775,  6.2427,  ...,  0.9851, 10.4215, 32.8453],
       device='cuda:0', grad_fn=<MulBackward0>)
425
torch.Size([4096]) tensor([53.4420, 27.0762,  6.2394,  ...,  0.9869, 10.4199, 32.8481],
       device='cuda:0', grad_fn=<MulBackward0>)
426
torch.Size([4096]) tensor([53.4400, 27.0746,  6.2400,  ...,  0.9852, 10.4214, 32.8495],
       device='cuda:0', grad_fn=<MulBackward0>)
427
torch.Size([4096]) t

torch.Size([4096]) tensor([53.4474, 27.0966,  6.2524,  ...,  0.9859, 10.4155, 32.8400],
       device='cuda:0', grad_fn=<MulBackward0>)
481
torch.Size([4096]) tensor([53.4632, 27.0997,  6.2508,  ...,  0.9850, 10.4181, 32.8484],
       device='cuda:0', grad_fn=<MulBackward0>)
482
torch.Size([4096]) tensor([53.4844, 27.0993,  6.2446,  ...,  0.9848, 10.4207, 32.8635],
       device='cuda:0', grad_fn=<MulBackward0>)
483
torch.Size([4096]) tensor([53.4938, 27.0994,  6.2435,  ...,  0.9829, 10.4221, 32.8688],
       device='cuda:0', grad_fn=<MulBackward0>)
484
torch.Size([4096]) tensor([53.4891, 27.1016,  6.2455,  ...,  0.9830, 10.4213, 32.8652],
       device='cuda:0', grad_fn=<MulBackward0>)
485
torch.Size([4096]) tensor([53.5078, 27.1050,  6.2387,  ...,  0.9867, 10.4223, 32.8785],
       device='cuda:0', grad_fn=<MulBackward0>)
486
torch.Size([4096]) tensor([53.4909, 27.1048,  6.2457,  ...,  0.9849, 10.4196, 32.8658],
       device='cuda:0', grad_fn=<MulBackward0>)
487
torch.Size([4096]) t

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.