In [1]:
import os
import pandas as pd 
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from PIL import Image
import glob

from torchvision import models
import tqdm

import time
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision.transforms import Resize, Compose, ToPILImage, ToTensor
import pickle
import math

#from efficientnet_pytorch import EfficientNet

#from kornia.filters import SpatialGradient

import random
from torchvision.transforms import RandomCrop

In [11]:
patch_size = (512, 512)

In [12]:
class MonocularDepthDataset(Dataset):
    def __init__(self, df, transform=None, crop_size=patch_size):
        self.df = df
        self.transform = transform
        self.crop_size = crop_size
        

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_path = self.df.iloc[idx]['image']
        depth_path = self.df.iloc[idx]['depth']

        image = Image.open(image_path) ##no rgb, takes grayscale
        depth = Image.open(depth_path)

        # randomly crop image and depth
        i, j, h, w = RandomCrop.get_params(image, output_size=(self.crop_size[0], self.crop_size[1]))
        image = TF.crop(image, i, j, h, w)
        depth = TF.crop(depth, i, j, h, w)

        if self.transform:
            image = self.transform(image)
            depth = self.transform(depth)

        return image, depth

In [13]:


def gradient_loss_fn(gen_frames, gt_frames, alpha=1):
    def gradient(x):
        # idea from tf.image.image_gradients(image)
        # https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/python/ops/image_ops_impl.py#L3441-L3512
        # x: (b,c,h,w), float32 or float64
        # dx, dy: (b,c,h,w)

        h_x = x.size()[-2]
        w_x = x.size()[-1]
        # gradient step=1
        left = x
        right = F.pad(x, [0, 1, 0, 0])[:, :, :, 1:]
        top = x
        bottom = F.pad(x, [0, 0, 0, 1])[:, :, 1:, :]

        # dx, dy = torch.abs(right - left), torch.abs(bottom - top)
        dx, dy = right - left, bottom - top 
        # dx will always have zeros in the last column, right-left
        # dy will always have zeros in the last row,    bottom-top
        dx[:, :, :, -1] = 0
        dy[:, :, -1, :] = 0

        return dx, dy

    # gradient
    gen_dx, gen_dy = gradient(gen_frames)
    gt_dx, gt_dy = gradient(gt_frames)
    #
    grad_diff_x = torch.abs(gt_dx - gen_dx)
    grad_diff_y = torch.abs(gt_dy - gen_dy)

    # condense into one tensor and avg
    return torch.mean(grad_diff_x ** alpha + grad_diff_y ** alpha)

class DepthEstimationLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(DepthEstimationLoss, self).__init__()
        self.alpha = alpha


    def forward(self, pred_depth, true_depth):
        pred_depth = torch.clamp(pred_depth, min=1e-8)
        true_depth = torch.clamp(true_depth, min=1e-8)

        # Scale-invariant MSE loss
        diff = torch.log(pred_depth) - torch.log(true_depth)
        mse_loss = torch.mean(diff**2)
        #scale_invariant_mse_loss = mse_loss - (self.alpha * (torch.sum(diff)**2)) / (true_depth.numel()**2)

    

        #gradient_loss = gradient_loss_fn(pred_depth,true_depth,alpha=self.alpha)

        #total_loss = (scale_invariant_mse_loss + gradient_loss)/2

        return (torch.sum((pred_depth - true_depth)**2))**0.5#scale_invariant_mse_loss#total_loss

In [14]:
def conv_relu_block(in_channel,out_channel,kernel,padding):
    return nn.Sequential(
            nn.Conv2d(in_channel,out_channel, kernel_size = kernel, padding=padding),
            nn.ReLU()) #nn.ReLU(inplace=True) #nn.Ge

In [15]:
class vanilla_unet(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        self.input_1 = conv_relu_block(3,3,3,1) ##grayscale inputs
        #self.input_2 = conv_relu_block(64, 64, 3, 1) #no extra channels

        self.base_model = models.resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.l0 = nn.Sequential(*self.base_layers[:3])
        self.U0_conv = conv_relu_block(64, 64, 1, 0)
        self.conv_up0 = conv_relu_block(64 + 256, 128, 3, 1)

        self.l1 = nn.Sequential(*self.base_layers[3:5])
        self.U1_conv = conv_relu_block(64, 64, 1, 0)
        self.conv_up1 = conv_relu_block(64 + 256, 256, 3, 1)

        self.l2 = self.base_layers[5]
        self.U2_conv = conv_relu_block(128, 128, 1, 0)
        self.conv_up2 = conv_relu_block(128 + 512, 256, 3, 1)

        self.l3 = self.base_layers[6]
        self.U3_conv = conv_relu_block(256, 256, 1, 0)
        self.conv_up3 = conv_relu_block(256 + 512, 512, 3, 1)

        self.l4 = self.base_layers[7]
        self.U4_conv = conv_relu_block(512, 512, 1, 0)

        self.conv_up4 = conv_relu_block(64 + 128, 64, 3, 1)

        self.out4 = nn.Conv2d(128, n_class, 1)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        x = torch.cat([x,x,x], axis = 1)
        x = self.input_1(x)
        
        #print(x.shape,'x')
         #concat on channel
        #x_one = self.input_2(x_one)
        block0 = self.l0(x)
        block1 = self.l1(block0)
        block2 = self.l2(block1)
        block3 = self.l3(block2)
        block4 = self.l4(block3)
        #print('b0: ', block0.shape)
        #print('b1: ', block1.shape)
        #print('b2: ', block2.shape)
        #print('b3: ', block3.shape)
        #print('b4: ', block4.shape)
        

        block4 = self.U4_conv(block4)
        x = self.upsample(block4)

        #print(block0.shape, block1.shape, block2.shape,block3.shape,block4.shape)
        block3 = self.U3_conv(block3)
        #print(x.shape, block3.shape)
        x = torch.cat([x, block3], axis=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        block2 = self.U2_conv(block2)
        #print('x shape: ', x.shape)
        #print('block2 precat: ', block2.shape)
        x = torch.cat([x, block2], axis=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        block1 = self.U1_conv(block1)
        x = torch.cat([x, block1], axis=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        block0 = self.U0_conv(block0)
        x = torch.cat([x, block0], axis=1)
        x = self.conv_up0(x)
        out4 = self.out4(x)

        out4_upsampled = F.interpolate(out4, scale_factor=2, mode='bilinear', align_corners=True)
        
        #relu = nn.ReLU()
        out = out4_upsampled#relu(out4_upsampled)
        
        
        return out



In [70]:
v = vanilla_unet(5)

In [71]:
x = torch.ones((1,1,patch_size[0],patch_size[1]))
#print(x.shape)
v.forward(x)

torch.Size([1, 1, 1024, 1024])
b0:  torch.Size([1, 64, 512, 512])
b1:  torch.Size([1, 64, 256, 256])
b2:  torch.Size([1, 128, 128, 128])
b3:  torch.Size([1, 256, 64, 64])
b4:  torch.Size([1, 512, 32, 32])
x shape:  torch.Size([1, 512, 128, 128])
block2 precat:  torch.Size([1, 128, 128, 128])


tensor([[[[ 0.2324,  0.2753,  0.3181,  ...,  0.2258,  0.1913,  0.1567],
          [ 0.1522,  0.1755,  0.1987,  ...,  0.1424,  0.1067,  0.0710],
          [ 0.0720,  0.0757,  0.0793,  ...,  0.0590,  0.0221, -0.0148],
          ...,
          [ 0.1539,  0.1555,  0.1572,  ...,  0.1608,  0.1143,  0.0678],
          [ 0.1300,  0.1384,  0.1468,  ...,  0.1555,  0.1270,  0.0986],
          [ 0.1061,  0.1213,  0.1364,  ...,  0.1502,  0.1398,  0.1293]],

         [[ 0.0386,  0.0779,  0.1171,  ...,  0.1012,  0.0768,  0.0524],
          [ 0.0123,  0.0582,  0.1040,  ...,  0.0794,  0.0827,  0.0860],
          [-0.0140,  0.0384,  0.0909,  ...,  0.0577,  0.0886,  0.1196],
          ...,
          [ 0.0289,  0.0508,  0.0728,  ...,  0.0127,  0.0350,  0.0573],
          [ 0.0424,  0.0579,  0.0733,  ...,  0.0497,  0.0531,  0.0564],
          [ 0.0560,  0.0649,  0.0737,  ...,  0.0867,  0.0711,  0.0555]],

         [[-0.0288,  0.0238,  0.0764,  ...,  0.0560,  0.0364,  0.0169],
          [-0.0580,  0.0121,  

In [16]:
#model = depth_model(num_classes=1).to('cuda')
#model = resunet(n_class=1).to('cuda')
model = vanilla_unet(n_class=1).to('cuda')

#model = effunet(n_class=1).to('cuda')



In [9]:
cd ../../krishna/project

/projectnb/cs585bp/krishna/project


In [17]:
# Set hyperparameters, dataset paths, and other configurations
batch_size = 8
learning_rate = 0.0005
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize(patch_size),
    transforms.ToTensor()
])

df = pd.read_csv('train.csv')
train_dataset = MonocularDepthDataset(df, transform = transform)
#val_dataset = MonocularDepthDataset(val_image_paths, val_depth_paths, transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=12)
#val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


criterion = DepthEstimationLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in tqdm.tqdm_notebook(range(num_epochs)):
    #train_loss = train(model, train_dataloader, optimizer, criterion, device)
    
    model.train()
    running_loss = 0.0
    
    for images, depths in tqdm.tqdm_notebook(train_dataloader):
        images = images.to(device)
        depths = depths.to(device)
        
        mask = depths == 0
        f_img
        
        optimizer.zero_grad()

        outputs = model(images)
        
        loss = criterion(outputs[-1].float(), depths.float())
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
    train_loss = running_loss / len(train_dataloader)
    print('Training loss: ', train_loss)
    print(outputs)
    #val_loss = validate(model, val_dataloader, criterion, device)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for epoch in tqdm.tqdm_notebook(range(num_epochs)):


  0%|          | 0/10 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for images, depths in tqdm.tqdm_notebook(train_dataloader):


  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1523079.0421316964
tensor([[[[225.9700, 299.2883, 372.6066,  ..., 399.8113, 320.6970, 241.5827],
          [296.3044, 393.2942, 490.2840,  ..., 527.6954, 423.2629, 318.8303],
          [366.6387, 487.3001, 607.9614,  ..., 655.5797, 525.8288, 396.0779],
          ...,
          [402.1918, 532.0461, 661.9005,  ..., 627.2892, 504.0517, 380.8141],
          [324.9572, 429.4845, 534.0118,  ..., 504.5987, 405.5048, 306.4109],
          [247.7227, 326.9229, 406.1231,  ..., 381.9081, 306.9579, 232.0077]]],


        [[[230.1348, 304.7097, 379.2845,  ..., 397.3133, 319.5917, 241.8702],
          [302.3275, 401.2037, 500.0799,  ..., 524.8654, 422.1418, 319.4181],
          [374.5204, 497.6978, 620.8752,  ..., 652.4176, 524.6918, 396.9661],
          ...,
          [359.4037, 477.4421, 595.4806,  ..., 637.5113, 512.9872, 388.4631],
          [289.6819, 384.6567, 479.6316,  ..., 513.2006, 413.0503, 312.9001],
          [219.9601, 291.8713, 363.7825,  ..., 388.8900, 313.1135, 237.33

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1474671.5248325893
tensor([[[[231.9027, 301.5659, 371.2292,  ..., 380.2215, 309.5569, 238.8922],
          [299.2137, 389.5897, 479.9656,  ..., 492.2044, 400.6057, 309.0069],
          [366.5248, 477.6134, 588.7020,  ..., 604.1873, 491.6545, 379.1217],
          ...,
          [383.2927, 495.1938, 607.0948,  ..., 570.9784, 465.0889, 359.1995],
          [315.8765, 407.8130, 499.7496,  ..., 465.2250, 379.0339, 292.8427],
          [248.4602, 320.4323, 392.4044,  ..., 359.4717, 292.9788, 226.4859]]],


        [[[235.9197, 305.9644, 376.0091,  ..., 369.1043, 300.9142, 232.7242],
          [304.3523, 395.1953, 486.0382,  ..., 477.0979, 388.8683, 300.6386],
          [372.7849, 484.4261, 596.0673,  ..., 585.0914, 476.8222, 368.5530],
          ...,
          [345.5391, 449.8457, 554.1523,  ..., 590.8282, 481.7087, 372.5893],
          [281.0870, 365.7562, 450.4255,  ..., 482.3787, 393.3954, 304.4121],
          [216.6349, 281.6668, 346.6986,  ..., 373.9292, 305.0820, 236.23

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1467518.6082589286
tensor([[[[344.5898, 437.2983, 530.0067,  ..., 528.1404, 435.0603, 341.9802],
          [435.0442, 553.0107, 670.9772,  ..., 670.5991, 552.5226, 434.4462],
          [525.4985, 668.7231, 811.9476,  ..., 813.0578, 669.9850, 526.9122],
          ...,
          [488.5919, 622.8073, 757.0226,  ..., 749.3574, 618.4691, 487.5807],
          [398.8768, 508.2958, 617.7148,  ..., 617.5297, 509.7918, 402.0539],
          [309.1617, 393.7844, 478.4072,  ..., 485.7020, 401.1145, 316.5271]]],


        [[[344.4405, 439.4238, 534.4072,  ..., 510.6799, 420.9977, 331.3156],
          [435.6964, 556.7875, 677.8787,  ..., 645.9640, 532.5869, 419.2097],
          [526.9523, 674.1512, 821.3502,  ..., 781.2481, 644.1760, 507.1039],
          ...,
          [491.3010, 625.2368, 759.1727,  ..., 743.5405, 614.4948, 485.4492],
          [403.1658, 512.7901, 622.4144,  ..., 613.4611, 507.1001, 400.7392],
          [315.0306, 400.3434, 485.6561,  ..., 483.3817, 399.7054, 316.02

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1463779.0066964286
tensor([[[[306.9858, 388.9007, 470.8156,  ..., 480.5141, 396.9550, 313.3960],
          [385.1759, 488.5612, 591.9464,  ..., 605.3864, 500.1345, 394.8826],
          [463.3661, 588.2217, 713.0773,  ..., 730.2586, 603.3140, 476.3693],
          ...,
          [490.1723, 621.2407, 752.3091,  ..., 748.0969, 619.2391, 490.3813],
          [405.6099, 513.9461, 622.2823,  ..., 616.6927, 510.7030, 404.7133],
          [321.0475, 406.6516, 492.2556,  ..., 485.2885, 402.1669, 319.0453]]],


        [[[360.3994, 458.9185, 557.4376,  ..., 436.1145, 359.6198, 283.1252],
          [452.9656, 576.6262, 700.2866,  ..., 548.2275, 452.0008, 355.7741],
          [545.5317, 694.3337, 843.1357,  ..., 660.3405, 544.3817, 428.4230],
          ...,
          [496.9497, 630.9760, 765.0023,  ..., 840.9426, 695.8099, 550.6772],
          [409.7520, 520.2670, 630.7820,  ..., 695.8018, 575.5226, 455.2434],
          [322.5544, 409.5580, 496.5616,  ..., 550.6611, 455.2353, 359.80

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1458216.4210379464
tensor([[[[388.8443, 485.8908, 582.9374,  ..., 535.7897, 443.3846, 350.9795],
          [482.5026, 603.0693, 723.6360,  ..., 666.7281, 552.2145, 437.7008],
          [576.1608, 720.2477, 864.3346,  ..., 797.6667, 661.0444, 524.4221],
          ...,
          [519.9626, 655.4701, 790.9776,  ..., 804.8029, 671.1104, 537.4180],
          [430.2531, 542.1044, 653.9557,  ..., 669.9803, 558.5977, 447.2151],
          [340.5435, 428.7386, 516.9337,  ..., 535.1578, 446.0851, 357.0122]]],


        [[[373.7723, 469.3964, 565.0204,  ..., 553.0355, 460.9529, 368.8703],
          [465.5296, 584.9773, 704.4249,  ..., 685.6057, 571.8902, 458.1747],
          [557.2869, 700.5582, 843.8293,  ..., 818.1758, 682.8275, 547.4792],
          ...,
          [569.8797, 704.7126, 839.5454,  ..., 803.2651, 670.7146, 538.1641],
          [477.3312, 590.3796, 703.4279,  ..., 668.4232, 558.1406, 447.8582],
          [384.7827, 476.0465, 567.3103,  ..., 533.5812, 445.5667, 357.55

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1480732.0301339286
tensor([[[[462.1521, 566.7400, 671.3279,  ..., 611.2800, 513.0312, 414.7823],
          [562.5982, 689.0969, 815.5957,  ..., 741.1384, 622.9814, 504.8244],
          [663.0443, 811.4539, 959.8635,  ..., 870.9968, 732.9316, 594.8665],
          ...,
          [665.6561, 807.9935, 950.3309,  ..., 853.2244, 724.0930, 594.9614],
          [561.4126, 682.1464, 802.8802,  ..., 719.5746, 610.2693, 500.9641],
          [457.1691, 556.2993, 655.4296,  ..., 585.9247, 496.4458, 406.9668]]],


        [[[467.1220, 573.3644, 679.6069,  ..., 634.4144, 534.9733, 435.5323],
          [569.4924, 698.4490, 827.4056,  ..., 766.2242, 647.6335, 529.0428],
          [671.8627, 823.5335, 975.2043,  ..., 898.0340, 760.2937, 622.5533],
          ...,
          [600.7769, 734.7234, 868.6700,  ..., 855.2563, 725.0364, 594.8165],
          [501.4323, 613.3328, 725.2334,  ..., 721.0593, 610.8246, 500.5899],
          [402.0877, 491.9422, 581.7968,  ..., 586.8623, 496.6128, 406.36

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1467009.1509486607
tensor([[[[367.5177, 453.5918, 539.6659,  ..., 528.4523, 443.1931, 357.9339],
          [451.6865, 557.2608, 662.8350,  ..., 647.5078, 543.5438, 439.5798],
          [535.8553, 660.9297, 786.0042,  ..., 766.5632, 643.8945, 521.2258],
          ...,
          [511.0103, 627.9118, 744.8134,  ..., 707.9538, 594.2532, 480.5525],
          [427.0215, 524.9450, 622.8686,  ..., 592.2486, 496.9713, 401.6940],
          [343.0327, 421.9782, 500.9238,  ..., 476.5433, 399.6894, 322.8355]]],


        [[[388.1499, 478.6607, 569.1715,  ..., 530.0013, 443.3775, 356.7536],
          [475.7424, 586.4191, 697.0958,  ..., 649.4852, 543.5722, 437.6592],
          [563.3348, 694.1774, 825.0201,  ..., 768.9691, 643.7669, 518.5647],
          ...,
          [567.3347, 698.9513, 830.5681,  ..., 783.9352, 660.5084, 537.0815],
          [476.5744, 587.7206, 698.8669,  ..., 656.6927, 553.4061, 450.1196],
          [385.8141, 476.4899, 567.1657,  ..., 529.4503, 446.3040, 363.15

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1465814.258091518
tensor([[[[394.7250, 481.3030, 567.8811,  ..., 546.9627, 460.9466, 374.9305],
          [478.3272, 582.6714, 687.0155,  ..., 662.1572, 558.5892, 455.0210],
          [561.9295, 684.0397, 806.1500,  ..., 777.3518, 656.2317, 535.1116],
          ...,
          [524.8081, 637.8411, 750.8741,  ..., 788.5690, 669.7343, 550.8996],
          [440.8127, 536.1132, 631.4137,  ..., 665.0012, 564.3863, 463.7714],
          [356.8173, 434.3853, 511.9532,  ..., 541.4335, 459.0383, 376.6431]]],


        [[[394.2775, 481.0466, 567.8157,  ..., 607.6087, 514.9890, 422.3692],
          [477.8984, 582.5421, 687.1858,  ..., 733.4951, 623.4977, 513.5003],
          [561.5193, 684.0376, 806.5559,  ..., 859.3815, 732.0065, 604.6315],
          ...,
          [616.2441, 747.1074, 877.9706,  ..., 814.4071, 692.3724, 570.3377],
          [522.8502, 635.0223, 747.1943,  ..., 691.0732, 586.9860, 482.8987],
          [429.4564, 522.9371, 616.4179,  ..., 567.7394, 481.5996, 395.459

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1473001.5968191964
tensor([[[[424.9778, 515.8203, 606.6627,  ..., 568.1049, 482.4688, 396.8328],
          [513.2971, 622.1977, 731.0983,  ..., 685.2303, 583.3552, 481.4799],
          [601.6164, 728.5751, 855.5338,  ..., 802.3558, 684.2415, 566.1271],
          ...,
          [525.1988, 634.2861, 743.3735,  ..., 763.6105, 651.1446, 538.6787],
          [442.4278, 534.7995, 627.1713,  ..., 646.3289, 550.7292, 455.1296],
          [359.6567, 435.3129, 510.9691,  ..., 529.0474, 450.3139, 371.5805]]],


        [[[406.9640, 493.1806, 579.3972,  ..., 524.1158, 444.2903, 364.4648],
          [492.2733, 595.6853, 699.0973,  ..., 629.3607, 534.4484, 439.5363],
          [577.5826, 698.1900, 818.7974,  ..., 734.6055, 624.6066, 514.6077],
          ...,
          [620.3486, 758.2371, 896.1256,  ..., 720.1153, 612.7184, 505.3216],
          [525.8380, 644.1166, 762.3951,  ..., 609.7502, 518.4512, 427.1523],
          [431.3274, 529.9960, 628.6646,  ..., 499.3850, 424.1841, 348.98

  0%|          | 0/224 [00:00<?, ?it/s]

Training loss:  1475661.4481026786
tensor([[[[ 519.4010,  628.2874,  737.1738,  ...,  689.0264,  587.6663,
            486.3061],
          [ 623.0278,  752.8343,  882.6408,  ...,  826.3239,  706.7734,
            587.2228],
          [ 726.6545,  877.3811, 1028.1078,  ...,  963.6213,  825.8804,
            688.1396],
          ...,
          [ 665.1313,  798.4329,  931.7346,  ...,  895.6385,  764.7383,
            633.8381],
          [ 562.1439,  675.7933,  789.4428,  ...,  760.2930,  648.4470,
            536.6010],
          [ 459.1565,  553.1538,  647.1511,  ...,  624.9475,  532.1558,
            439.3640]]],


        [[[ 498.7350,  603.8744,  709.0139,  ...,  700.3795,  596.3596,
            492.3397],
          [ 601.0522,  725.8787,  850.7052,  ...,  842.1880,  718.9941,
            595.8004],
          [ 703.3694,  847.8830,  992.3966,  ...,  983.9965,  841.6287,
            699.2610],
          ...,
          [ 653.9136,  780.7143,  907.5150,  ...,  953.2515,  816.7070,
    

In [None]:
cd ../../nkono/IVC_MDE

In [None]:
torch.save(model.state_dict(), 'good_small_model.pt')

In [None]:
pwd