In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms, models # add models to the list
from torchvision.utils import make_grid
import os

from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn import preprocessing

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import os
from PIL import Image
from IPython.display import display
import open3d as o3d

# Filter harmless warnings
import warnings
warnings.filterwarnings("ignore")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


# Wczytanie pliku .csv
Zaczynamy od wczytania pliku csv w którym są zapisane ścieżki do poszczególnych zdjęć oraz współrzędne punktów

In [2]:
df = pd.read_csv('./files/data.csv') 
df_INPUT_DEPTH = df[['depth_img_I', 'depth_img_II']]
df_INPUT_RGB = df[['rgb_img_I', 'rgb_img_II']]
df_OUTPUT = df[['x1','y1','z1','x2','y2','z2']]

In [3]:
# torch.cuda.is_available()

# Funkcje
Stworzenie fukncji, które tłumaczą pliki na język matematyczny

In [4]:
TrasformData = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
ResizeData = transforms.Resize(480)


def load_rgb_dataset(top_dir='./files/images/rgb_img_I'):
    images_dataset = []
    for root, dirs, files in os.walk(top_dir):
        for name in files:
            # print(os.path.join(root, name))
            img = np.array(Image.open(os.path.join(root, name)))
            np.array(images_dataset.append(img))
    return np.array(images_dataset)

def load_depth_dataset(depth_series, row=0):    
    depth_dataset = []
    for n in range(depth_series.shape[0]):
        depth = np.array(np.load(depth_series.iloc[n,row]))
        np.array(depth_dataset.append(depth))
    
    return np.array(depth_dataset)

def NormImage(img,div_val,reshape=False,transform=False):
    if reshape == True:
        img = img.reshape((img.shape[0], img.shape[1], img.shape[2], 1))
                    
    img = img.transpose(0,3,1,2)
    img = torch.FloatTensor(img).div(div_val)
    if transform == True:
        img = TrasformData(img)
    
    return img

# Załadowanie danych
Ładujemy dane do zmiennych a następnie odpowiednio przekształcamy

In [5]:
RGBimg_begin = load_rgb_dataset('./files/images/rgb_img_I')
RGBimg_end = load_rgb_dataset('./files/images/rgb_img_II')

In [6]:
DEPTHimg_begin = load_depth_dataset(df_INPUT_DEPTH,0)
DEPTHimg_end = load_depth_dataset(df_INPUT_DEPTH,1)
# print(DEPTHimg_end[30][400])

In [7]:
# Taking first rgb image
rgb_in = NormImage(RGBimg_begin,255,transform=True)

# Taking depth beginning state of the movement
depthBeg_in = NormImage(DEPTHimg_begin,65535,reshape=True)

# Taking depth end state of the movement
depthEnd_in = NormImage(DEPTHimg_end,65535,reshape=True)

# Taking depth difference between movements
depthDiff_in = NormImage(abs(DEPTHimg_begin - DEPTHimg_end),65535,reshape=True)

# Taking outputs
axis_out = df_OUTPUT.values
axis_out = torch.Tensor(axis_out)

# Connected input for model
DDD_in = torch.cat((depthBeg_in, depthEnd_in, depthDiff_in),axis=1)
# DDD_in = TrasformData(DDD_in)
print(DDD_in.shape)

RGBD_input = ResizeData(torch.cat((rgb_in, DDD_in),axis=1))
print(RGBD_input.shape)
# RGBD_input = DDD_in

torch.Size([155, 3, 480, 640])
torch.Size([155, 6, 480, 640])


In [8]:
# plt.imshow(RGBD_input[30][0].numpy())

# RGB + D 
Na wejście do modelu zostanie podany tensor zawieający kombinację RGB + D 

In [9]:
# rgb_inTensor = torch.tensor(rgb_in.astype(float), dtype=torch.float)
# depth_inTensor = torch.tensor(depth_in.astype(float), dtype=torch.float)
# axis_outTensor = torch.tensor(axis_out.astype(float), dtype=torch.float)
X_train, X_validation, y_train, y_validation = train_test_split(RGBD_input, axis_out, test_size=0.1)

In [10]:
X_validation.shape

torch.Size([16, 6, 480, 640])

In [11]:
AoRD_trainDataset = TensorDataset(X_train, y_train)
AoRD_validationDataset = TensorDataset(X_validation, y_validation)

In [12]:
train_loader = DataLoader(AoRD_trainDataset, batch_size=5, shuffle=True)
validation_loader = DataLoader(AoRD_validationDataset, batch_size=1, shuffle=True)

In [13]:
# plt.imshow(rgb_in[34])

In [14]:
# ResNetModel = models.resnet101(pretrained=False)
# ResNetModel

In [15]:
# ResNetModel.conv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Wyciąganie pojedyńczego elementu z batcha
Można zrobić to na kilka sposobów, ale ten jest najszybszy

In [16]:
for b, (X_train, y_train) in enumerate(train_loader):
    pass

In [17]:
X_train.dtype

torch.float32

# Stworzenie modelu
Nazwałem model AoRNet od angielsiego **A**xis **o**f **R**rotation oraz od nazwy modelu matki Res**Net**`u 

In [16]:
class AoRNet(nn.Module):
    def __init__(self,pretrained=False ,input_channels=6, output_size=6):
        super().__init__()
        self.resnet50 = models.resnet50(pretrained=pretrained)
        self.resnet50.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.resnet50.fc = nn.Linear(in_features=2048, out_features=output_size, bias=True)
    
    def forward(self, X):
        return self.resnet50(X)

In [17]:
Model = AoRNet().cuda()

In [18]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(Model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min')

In [19]:
def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for item in params:
        print(f'{item:>8}')
    print(f'________\n{sum(params):>8}')
# count_parameters(My_Model)

In [22]:
epochs = 250

train_losses = []
validation_losses = []

for i in range(epochs):
    # Run the training batches
    for b, (X_train, y_train) in enumerate(train_loader):
        b+=1
        # Apply the model
        y_pred = Model(X_train.cuda())
        loss = criterion(y_pred, y_train.cuda())
#         torch.cuda.empty_cache()
        # Update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print interim results
        if b%1 == 0:
            print(f'epoch: {i+1:2}  batch: {b}  loss: {loss.item():10.8f}')
    
    train_losses.append(loss.cpu().detach().numpy())
    scheduler.step(loss)
    
    # Run the validationing batches
    with torch.no_grad():
        for b, (X_validation, y_validation) in enumerate(validation_loader):
            # Apply the model
            y_val = Model(X_validation.cuda())
    loss = criterion(y_val, y_validation.cuda())
    validation_losses.append(loss.cpu().detach().numpy())
#     validation_correct.append(tst_corr)

epoch:  1  batch: 1  loss: 2.77987123
epoch:  1  batch: 2  loss: 2.13590026
epoch:  1  batch: 3  loss: 0.40691528
epoch:  1  batch: 4  loss: 0.40870586
epoch:  1  batch: 5  loss: 0.25745320
epoch:  1  batch: 6  loss: 0.13410991
epoch:  1  batch: 7  loss: 0.08647025
epoch:  1  batch: 8  loss: 0.09429970
epoch:  1  batch: 9  loss: 0.05336744
epoch:  1  batch: 10  loss: 0.27108589
epoch:  1  batch: 11  loss: 0.11276813
epoch:  1  batch: 12  loss: 0.07010373
epoch:  1  batch: 13  loss: 0.08267649
epoch:  1  batch: 14  loss: 0.07709704
epoch:  1  batch: 15  loss: 0.09936238
epoch:  1  batch: 16  loss: 0.06645229
epoch:  1  batch: 17  loss: 0.07738516
epoch:  1  batch: 18  loss: 0.11031090
epoch:  1  batch: 19  loss: 0.07467216
epoch:  1  batch: 20  loss: 0.12405441
epoch:  1  batch: 21  loss: 0.06330313
epoch:  1  batch: 22  loss: 0.03751905
epoch:  1  batch: 23  loss: 0.05915470
epoch:  1  batch: 24  loss: 0.08281343
epoch:  1  batch: 25  loss: 0.02921475
epoch:  1  batch: 26  loss: 0.0185

epoch:  8  batch: 17  loss: 0.03073475
epoch:  8  batch: 18  loss: 0.02411196
epoch:  8  batch: 19  loss: 0.01096667
epoch:  8  batch: 20  loss: 0.02920349
epoch:  8  batch: 21  loss: 0.03204548
epoch:  8  batch: 22  loss: 0.06185277
epoch:  8  batch: 23  loss: 0.01997651
epoch:  8  batch: 24  loss: 0.04428843
epoch:  8  batch: 25  loss: 0.03709837
epoch:  8  batch: 26  loss: 0.01388501
epoch:  8  batch: 27  loss: 0.02052389
epoch:  8  batch: 28  loss: 0.01125665
epoch:  9  batch: 1  loss: 0.03055094
epoch:  9  batch: 2  loss: 0.04751106
epoch:  9  batch: 3  loss: 0.03224637
epoch:  9  batch: 4  loss: 0.01474689
epoch:  9  batch: 5  loss: 0.07403738
epoch:  9  batch: 6  loss: 0.02731878
epoch:  9  batch: 7  loss: 0.02051228
epoch:  9  batch: 8  loss: 0.05491201
epoch:  9  batch: 9  loss: 0.03595900
epoch:  9  batch: 10  loss: 0.03837840
epoch:  9  batch: 11  loss: 0.01810339
epoch:  9  batch: 12  loss: 0.04211831
epoch:  9  batch: 13  loss: 0.06902790
epoch:  9  batch: 14  loss: 0.0351

epoch: 16  batch: 5  loss: 0.04501213
epoch: 16  batch: 6  loss: 0.06077906
epoch: 16  batch: 7  loss: 0.07162433
epoch: 16  batch: 8  loss: 0.17212296
epoch: 16  batch: 9  loss: 0.13831285
epoch: 16  batch: 10  loss: 0.04777497
epoch: 16  batch: 11  loss: 0.03923805
epoch: 16  batch: 12  loss: 0.17826211
epoch: 16  batch: 13  loss: 0.04901672
epoch: 16  batch: 14  loss: 0.04651303
epoch: 16  batch: 15  loss: 0.02177637
epoch: 16  batch: 16  loss: 0.05341881
epoch: 16  batch: 17  loss: 0.05477630
epoch: 16  batch: 18  loss: 0.03114882
epoch: 16  batch: 19  loss: 0.03101228
epoch: 16  batch: 20  loss: 0.03951875
epoch: 16  batch: 21  loss: 0.01806356
epoch: 16  batch: 22  loss: 0.03335533
epoch: 16  batch: 23  loss: 0.03561433
epoch: 16  batch: 24  loss: 0.02659368
epoch: 16  batch: 25  loss: 0.01467277
epoch: 16  batch: 26  loss: 0.04990435
epoch: 16  batch: 27  loss: 0.01887333
epoch: 16  batch: 28  loss: 0.01604225
epoch: 17  batch: 1  loss: 0.02350834
epoch: 17  batch: 2  loss: 0.05

epoch: 23  batch: 21  loss: 0.01718898
epoch: 23  batch: 22  loss: 0.01391705
epoch: 23  batch: 23  loss: 0.00920560
epoch: 23  batch: 24  loss: 0.00292695
epoch: 23  batch: 25  loss: 0.00448962
epoch: 23  batch: 26  loss: 0.02371311
epoch: 23  batch: 27  loss: 0.00233507
epoch: 23  batch: 28  loss: 0.02494265
epoch: 24  batch: 1  loss: 0.01020098
epoch: 24  batch: 2  loss: 0.01257771
epoch: 24  batch: 3  loss: 0.01682682
epoch: 24  batch: 4  loss: 0.00383343
epoch: 24  batch: 5  loss: 0.06505097
epoch: 24  batch: 6  loss: 0.18674730
epoch: 24  batch: 7  loss: 0.00491677
epoch: 24  batch: 8  loss: 0.00544609
epoch: 24  batch: 9  loss: 0.00696608
epoch: 24  batch: 10  loss: 0.00972205
epoch: 24  batch: 11  loss: 0.00665667
epoch: 24  batch: 12  loss: 0.01778626
epoch: 24  batch: 13  loss: 0.00567752
epoch: 24  batch: 14  loss: 0.02518505
epoch: 24  batch: 15  loss: 0.01582591
epoch: 24  batch: 16  loss: 0.02618512
epoch: 24  batch: 17  loss: 0.02461114
epoch: 24  batch: 18  loss: 0.0086

epoch: 31  batch: 9  loss: 0.01424555
epoch: 31  batch: 10  loss: 0.00309240
epoch: 31  batch: 11  loss: 0.01234553
epoch: 31  batch: 12  loss: 0.00573700
epoch: 31  batch: 13  loss: 0.01841266
epoch: 31  batch: 14  loss: 0.03725782
epoch: 31  batch: 15  loss: 0.00283672
epoch: 31  batch: 16  loss: 0.00602878
epoch: 31  batch: 17  loss: 0.00553261
epoch: 31  batch: 18  loss: 0.00521779
epoch: 31  batch: 19  loss: 0.01667085
epoch: 31  batch: 20  loss: 0.01557848
epoch: 31  batch: 21  loss: 0.00514405
epoch: 31  batch: 22  loss: 0.00890283
epoch: 31  batch: 23  loss: 0.01653669
epoch: 31  batch: 24  loss: 0.00461402
epoch: 31  batch: 25  loss: 0.00763331
epoch: 31  batch: 26  loss: 0.03503365
epoch: 31  batch: 27  loss: 0.00541289
epoch: 31  batch: 28  loss: 0.01353684
epoch: 32  batch: 1  loss: 0.01109902
epoch: 32  batch: 2  loss: 0.01490990
epoch: 32  batch: 3  loss: 0.01312821
epoch: 32  batch: 4  loss: 0.00866296
epoch: 32  batch: 5  loss: 0.00759000
epoch: 32  batch: 6  loss: 0.00

epoch: 38  batch: 25  loss: 0.00736521
epoch: 38  batch: 26  loss: 0.01137324
epoch: 38  batch: 27  loss: 0.00868015
epoch: 38  batch: 28  loss: 0.00305318
epoch: 39  batch: 1  loss: 0.00505351
epoch: 39  batch: 2  loss: 0.02092802
epoch: 39  batch: 3  loss: 0.02913116
epoch: 39  batch: 4  loss: 0.00626676
epoch: 39  batch: 5  loss: 0.01370944
epoch: 39  batch: 6  loss: 0.00491000
epoch: 39  batch: 7  loss: 0.00393745
epoch: 39  batch: 8  loss: 0.00495561
epoch: 39  batch: 9  loss: 0.01020581
epoch: 39  batch: 10  loss: 0.00621899
epoch: 39  batch: 11  loss: 0.01037931
epoch: 39  batch: 12  loss: 0.00149806
epoch: 39  batch: 13  loss: 0.00272775
epoch: 39  batch: 14  loss: 0.01393400
epoch: 39  batch: 15  loss: 0.00394950
epoch: 39  batch: 16  loss: 0.00840202
epoch: 39  batch: 17  loss: 0.00506222
epoch: 39  batch: 18  loss: 0.00524521
epoch: 39  batch: 19  loss: 0.00291195
epoch: 39  batch: 20  loss: 0.00713525
epoch: 39  batch: 21  loss: 0.00852131
epoch: 39  batch: 22  loss: 0.0044

epoch: 46  batch: 13  loss: 0.00296361
epoch: 46  batch: 14  loss: 0.00683887
epoch: 46  batch: 15  loss: 0.01443273
epoch: 46  batch: 16  loss: 0.00678823
epoch: 46  batch: 17  loss: 0.00447670
epoch: 46  batch: 18  loss: 0.00724740
epoch: 46  batch: 19  loss: 0.00672037
epoch: 46  batch: 20  loss: 0.00300844
epoch: 46  batch: 21  loss: 0.00920757
epoch: 46  batch: 22  loss: 0.00402649
epoch: 46  batch: 23  loss: 0.00405417
epoch: 46  batch: 24  loss: 0.01549707
epoch: 46  batch: 25  loss: 0.01371940
epoch: 46  batch: 26  loss: 0.00350746
epoch: 46  batch: 27  loss: 0.00559805
epoch: 46  batch: 28  loss: 0.00349344
epoch: 47  batch: 1  loss: 0.00512720
epoch: 47  batch: 2  loss: 0.00718183
epoch: 47  batch: 3  loss: 0.00668633
epoch: 47  batch: 4  loss: 0.00673876
epoch: 47  batch: 5  loss: 0.00597318
epoch: 47  batch: 6  loss: 0.01492998
epoch: 47  batch: 7  loss: 0.01493251
epoch: 47  batch: 8  loss: 0.00326443
epoch: 47  batch: 9  loss: 0.01064783
epoch: 47  batch: 10  loss: 0.0105

epoch: 54  batch: 1  loss: 0.00398049
epoch: 54  batch: 2  loss: 0.00191866
epoch: 54  batch: 3  loss: 0.00715553
epoch: 54  batch: 4  loss: 0.00243594
epoch: 54  batch: 5  loss: 0.00300573
epoch: 54  batch: 6  loss: 0.00142835
epoch: 54  batch: 7  loss: 0.00655761
epoch: 54  batch: 8  loss: 0.00623864
epoch: 54  batch: 9  loss: 0.00248534
epoch: 54  batch: 10  loss: 0.00213433
epoch: 54  batch: 11  loss: 0.00440299
epoch: 54  batch: 12  loss: 0.01519646
epoch: 54  batch: 13  loss: 0.00842865
epoch: 54  batch: 14  loss: 0.00407317
epoch: 54  batch: 15  loss: 0.00942051
epoch: 54  batch: 16  loss: 0.00514378
epoch: 54  batch: 17  loss: 0.00293544
epoch: 54  batch: 18  loss: 0.00362836
epoch: 54  batch: 19  loss: 0.00990633
epoch: 54  batch: 20  loss: 0.00874521
epoch: 54  batch: 21  loss: 0.00412256
epoch: 54  batch: 22  loss: 0.00194791
epoch: 54  batch: 23  loss: 0.00531538
epoch: 54  batch: 24  loss: 0.00127961
epoch: 54  batch: 25  loss: 0.02313403
epoch: 54  batch: 26  loss: 0.0046

epoch: 61  batch: 17  loss: 0.01061077
epoch: 61  batch: 18  loss: 0.00802185
epoch: 61  batch: 19  loss: 0.00217963
epoch: 61  batch: 20  loss: 0.00576146
epoch: 61  batch: 21  loss: 0.04058519
epoch: 61  batch: 22  loss: 0.00211691
epoch: 61  batch: 23  loss: 0.00576417
epoch: 61  batch: 24  loss: 0.01008681
epoch: 61  batch: 25  loss: 0.00276529
epoch: 61  batch: 26  loss: 0.00352122
epoch: 61  batch: 27  loss: 0.01155243
epoch: 61  batch: 28  loss: 0.00311443
epoch: 62  batch: 1  loss: 0.00169339
epoch: 62  batch: 2  loss: 0.00327089
epoch: 62  batch: 3  loss: 0.00526732
epoch: 62  batch: 4  loss: 0.00796407
epoch: 62  batch: 5  loss: 0.00633893
epoch: 62  batch: 6  loss: 0.00451753
epoch: 62  batch: 7  loss: 0.00366386
epoch: 62  batch: 8  loss: 0.00385873
epoch: 62  batch: 9  loss: 0.02223451
epoch: 62  batch: 10  loss: 0.00328600
epoch: 62  batch: 11  loss: 0.00132230
epoch: 62  batch: 12  loss: 0.00257692
epoch: 62  batch: 13  loss: 0.00152246
epoch: 62  batch: 14  loss: 0.0035

epoch: 69  batch: 5  loss: 0.00216015
epoch: 69  batch: 6  loss: 0.00341273
epoch: 69  batch: 7  loss: 0.01152800
epoch: 69  batch: 8  loss: 0.00422645
epoch: 69  batch: 9  loss: 0.00334851
epoch: 69  batch: 10  loss: 0.00578682
epoch: 69  batch: 11  loss: 0.00722482
epoch: 69  batch: 12  loss: 0.00324116
epoch: 69  batch: 13  loss: 0.00196473
epoch: 69  batch: 14  loss: 0.02132596
epoch: 69  batch: 15  loss: 0.00103741
epoch: 69  batch: 16  loss: 0.00445250
epoch: 69  batch: 17  loss: 0.00136278
epoch: 69  batch: 18  loss: 0.00968396
epoch: 69  batch: 19  loss: 0.00289380
epoch: 69  batch: 20  loss: 0.00533204
epoch: 69  batch: 21  loss: 0.00400736
epoch: 69  batch: 22  loss: 0.01149686
epoch: 69  batch: 23  loss: 0.00431862
epoch: 69  batch: 24  loss: 0.00658526
epoch: 69  batch: 25  loss: 0.00271594
epoch: 69  batch: 26  loss: 0.01567597
epoch: 69  batch: 27  loss: 0.00178466
epoch: 69  batch: 28  loss: 0.00719803
epoch: 70  batch: 1  loss: 0.00327727
epoch: 70  batch: 2  loss: 0.00

epoch: 76  batch: 21  loss: 0.00586104
epoch: 76  batch: 22  loss: 0.00187598
epoch: 76  batch: 23  loss: 0.01188638
epoch: 76  batch: 24  loss: 0.00440509
epoch: 76  batch: 25  loss: 0.00401269
epoch: 76  batch: 26  loss: 0.02499672
epoch: 76  batch: 27  loss: 0.00391500
epoch: 76  batch: 28  loss: 0.00526127
epoch: 77  batch: 1  loss: 0.00603982
epoch: 77  batch: 2  loss: 0.00342333
epoch: 77  batch: 3  loss: 0.00234981
epoch: 77  batch: 4  loss: 0.00461140
epoch: 77  batch: 5  loss: 0.00546116
epoch: 77  batch: 6  loss: 0.00757484
epoch: 77  batch: 7  loss: 0.00384418
epoch: 77  batch: 8  loss: 0.00645573
epoch: 77  batch: 9  loss: 0.00179800
epoch: 77  batch: 10  loss: 0.00272523
epoch: 77  batch: 11  loss: 0.00730286
epoch: 77  batch: 12  loss: 0.00183278
epoch: 77  batch: 13  loss: 0.00162431
epoch: 77  batch: 14  loss: 0.00261608
epoch: 77  batch: 15  loss: 0.00272264
epoch: 77  batch: 16  loss: 0.00284028
epoch: 77  batch: 17  loss: 0.00351240
epoch: 77  batch: 18  loss: 0.0038

epoch: 84  batch: 9  loss: 0.01675901
epoch: 84  batch: 10  loss: 0.00144385
epoch: 84  batch: 11  loss: 0.00231220
epoch: 84  batch: 12  loss: 0.00109607
epoch: 84  batch: 13  loss: 0.00465169
epoch: 84  batch: 14  loss: 0.00806170
epoch: 84  batch: 15  loss: 0.00423431
epoch: 84  batch: 16  loss: 0.00932396
epoch: 84  batch: 17  loss: 0.00563385
epoch: 84  batch: 18  loss: 0.00143258
epoch: 84  batch: 19  loss: 0.00242556
epoch: 84  batch: 20  loss: 0.00080230
epoch: 84  batch: 21  loss: 0.00460717
epoch: 84  batch: 22  loss: 0.00095662
epoch: 84  batch: 23  loss: 0.00164855
epoch: 84  batch: 24  loss: 0.00279567
epoch: 84  batch: 25  loss: 0.00134991
epoch: 84  batch: 26  loss: 0.00890042
epoch: 84  batch: 27  loss: 0.00192919
epoch: 84  batch: 28  loss: 0.00141923
epoch: 85  batch: 1  loss: 0.00109542
epoch: 85  batch: 2  loss: 0.00912761
epoch: 85  batch: 3  loss: 0.00355384
epoch: 85  batch: 4  loss: 0.00413050
epoch: 85  batch: 5  loss: 0.00436473
epoch: 85  batch: 6  loss: 0.00

epoch: 91  batch: 25  loss: 0.01920279
epoch: 91  batch: 26  loss: 0.01138839
epoch: 91  batch: 27  loss: 0.00696591
epoch: 91  batch: 28  loss: 0.00227154
epoch: 92  batch: 1  loss: 0.00281146
epoch: 92  batch: 2  loss: 0.00086020
epoch: 92  batch: 3  loss: 0.01144700
epoch: 92  batch: 4  loss: 0.00271062
epoch: 92  batch: 5  loss: 0.00101651
epoch: 92  batch: 6  loss: 0.00236637
epoch: 92  batch: 7  loss: 0.01052010
epoch: 92  batch: 8  loss: 0.01326492
epoch: 92  batch: 9  loss: 0.00609394
epoch: 92  batch: 10  loss: 0.00219254
epoch: 92  batch: 11  loss: 0.00266868
epoch: 92  batch: 12  loss: 0.00171614
epoch: 92  batch: 13  loss: 0.00136480
epoch: 92  batch: 14  loss: 0.00396087
epoch: 92  batch: 15  loss: 0.00438533
epoch: 92  batch: 16  loss: 0.00823150
epoch: 92  batch: 17  loss: 0.00696047
epoch: 92  batch: 18  loss: 0.00584631
epoch: 92  batch: 19  loss: 0.00273301
epoch: 92  batch: 20  loss: 0.00673860
epoch: 92  batch: 21  loss: 0.00393938
epoch: 92  batch: 22  loss: 0.0040

epoch: 99  batch: 13  loss: 0.00509529
epoch: 99  batch: 14  loss: 0.00346363
epoch: 99  batch: 15  loss: 0.00452232
epoch: 99  batch: 16  loss: 0.00331137
epoch: 99  batch: 17  loss: 0.00186720
epoch: 99  batch: 18  loss: 0.02041185
epoch: 99  batch: 19  loss: 0.00133640
epoch: 99  batch: 20  loss: 0.00421229
epoch: 99  batch: 21  loss: 0.00228180
epoch: 99  batch: 22  loss: 0.00283300
epoch: 99  batch: 23  loss: 0.00301277
epoch: 99  batch: 24  loss: 0.00507246
epoch: 99  batch: 25  loss: 0.00364106
epoch: 99  batch: 26  loss: 0.00638569
epoch: 99  batch: 27  loss: 0.00743264
epoch: 99  batch: 28  loss: 0.00282645
epoch: 100  batch: 1  loss: 0.00381650
epoch: 100  batch: 2  loss: 0.01082819
epoch: 100  batch: 3  loss: 0.00286844
epoch: 100  batch: 4  loss: 0.01619374
epoch: 100  batch: 5  loss: 0.00449108
epoch: 100  batch: 6  loss: 0.00329164
epoch: 100  batch: 7  loss: 0.00558647
epoch: 100  batch: 8  loss: 0.00285734
epoch: 100  batch: 9  loss: 0.00718769
epoch: 100  batch: 10  lo

epoch: 106  batch: 24  loss: 0.00277173
epoch: 106  batch: 25  loss: 0.01487436
epoch: 106  batch: 26  loss: 0.00778734
epoch: 106  batch: 27  loss: 0.01606144
epoch: 106  batch: 28  loss: 0.00646652
epoch: 107  batch: 1  loss: 0.00501340
epoch: 107  batch: 2  loss: 0.00529428
epoch: 107  batch: 3  loss: 0.00812664
epoch: 107  batch: 4  loss: 0.00694290
epoch: 107  batch: 5  loss: 0.00333715
epoch: 107  batch: 6  loss: 0.02610985
epoch: 107  batch: 7  loss: 0.00724509
epoch: 107  batch: 8  loss: 0.00357293
epoch: 107  batch: 9  loss: 0.00143576
epoch: 107  batch: 10  loss: 0.00455629
epoch: 107  batch: 11  loss: 0.00365466
epoch: 107  batch: 12  loss: 0.00291636
epoch: 107  batch: 13  loss: 0.00204988
epoch: 107  batch: 14  loss: 0.01111590
epoch: 107  batch: 15  loss: 0.00660811
epoch: 107  batch: 16  loss: 0.00485716
epoch: 107  batch: 17  loss: 0.00480672
epoch: 107  batch: 18  loss: 0.00236820
epoch: 107  batch: 19  loss: 0.00305879
epoch: 107  batch: 20  loss: 0.00220062
epoch: 10

epoch: 114  batch: 7  loss: 0.01283375
epoch: 114  batch: 8  loss: 0.00214468
epoch: 114  batch: 9  loss: 0.00320999
epoch: 114  batch: 10  loss: 0.00686347
epoch: 114  batch: 11  loss: 0.03332075
epoch: 114  batch: 12  loss: 0.00426674
epoch: 114  batch: 13  loss: 0.00446182
epoch: 114  batch: 14  loss: 0.00709551
epoch: 114  batch: 15  loss: 0.00508754
epoch: 114  batch: 16  loss: 0.00507665
epoch: 114  batch: 17  loss: 0.00233556
epoch: 114  batch: 18  loss: 0.02183567
epoch: 114  batch: 19  loss: 0.00345576
epoch: 114  batch: 20  loss: 0.00264109
epoch: 114  batch: 21  loss: 0.00953281
epoch: 114  batch: 22  loss: 0.00530715
epoch: 114  batch: 23  loss: 0.01133305
epoch: 114  batch: 24  loss: 0.00343965
epoch: 114  batch: 25  loss: 0.00200918
epoch: 114  batch: 26  loss: 0.00289861
epoch: 114  batch: 27  loss: 0.00446053
epoch: 114  batch: 28  loss: 0.00373773
epoch: 115  batch: 1  loss: 0.00266231
epoch: 115  batch: 2  loss: 0.00596082
epoch: 115  batch: 3  loss: 0.00429270
epoch:

epoch: 121  batch: 18  loss: 0.00460192
epoch: 121  batch: 19  loss: 0.00429834
epoch: 121  batch: 20  loss: 0.00337144
epoch: 121  batch: 21  loss: 0.00675612
epoch: 121  batch: 22  loss: 0.00390243
epoch: 121  batch: 23  loss: 0.00308335
epoch: 121  batch: 24  loss: 0.00398126
epoch: 121  batch: 25  loss: 0.00842862
epoch: 121  batch: 26  loss: 0.00625966
epoch: 121  batch: 27  loss: 0.00406267
epoch: 121  batch: 28  loss: 0.00519278
epoch: 122  batch: 1  loss: 0.00232296
epoch: 122  batch: 2  loss: 0.00364715
epoch: 122  batch: 3  loss: 0.00325110
epoch: 122  batch: 4  loss: 0.00310315
epoch: 122  batch: 5  loss: 0.00281167
epoch: 122  batch: 6  loss: 0.00131193
epoch: 122  batch: 7  loss: 0.00227107
epoch: 122  batch: 8  loss: 0.01964162
epoch: 122  batch: 9  loss: 0.00234713
epoch: 122  batch: 10  loss: 0.00684764
epoch: 122  batch: 11  loss: 0.00201492
epoch: 122  batch: 12  loss: 0.00980014
epoch: 122  batch: 13  loss: 0.00520404
epoch: 122  batch: 14  loss: 0.02089509
epoch: 12

epoch: 129  batch: 1  loss: 0.00343619
epoch: 129  batch: 2  loss: 0.00535840
epoch: 129  batch: 3  loss: 0.01687443
epoch: 129  batch: 4  loss: 0.00285096
epoch: 129  batch: 5  loss: 0.00265339
epoch: 129  batch: 6  loss: 0.00111544
epoch: 129  batch: 7  loss: 0.00168272
epoch: 129  batch: 8  loss: 0.01712511
epoch: 129  batch: 9  loss: 0.00237027
epoch: 129  batch: 10  loss: 0.00363151
epoch: 129  batch: 11  loss: 0.00249726
epoch: 129  batch: 12  loss: 0.00379322
epoch: 129  batch: 13  loss: 0.00331148
epoch: 129  batch: 14  loss: 0.00633012
epoch: 129  batch: 15  loss: 0.00633982
epoch: 129  batch: 16  loss: 0.00161761
epoch: 129  batch: 17  loss: 0.00817306
epoch: 129  batch: 18  loss: 0.00332125
epoch: 129  batch: 19  loss: 0.00154419
epoch: 129  batch: 20  loss: 0.00152545
epoch: 129  batch: 21  loss: 0.00646288
epoch: 129  batch: 22  loss: 0.00288229
epoch: 129  batch: 23  loss: 0.00524431
epoch: 129  batch: 24  loss: 0.00356393
epoch: 129  batch: 25  loss: 0.01581565
epoch: 12

epoch: 136  batch: 12  loss: 0.00334623
epoch: 136  batch: 13  loss: 0.00284733
epoch: 136  batch: 14  loss: 0.00259497
epoch: 136  batch: 15  loss: 0.00212701
epoch: 136  batch: 16  loss: 0.00498746
epoch: 136  batch: 17  loss: 0.01333829
epoch: 136  batch: 18  loss: 0.00381670
epoch: 136  batch: 19  loss: 0.00461509
epoch: 136  batch: 20  loss: 0.00212902
epoch: 136  batch: 21  loss: 0.00525870
epoch: 136  batch: 22  loss: 0.00226432
epoch: 136  batch: 23  loss: 0.00233431
epoch: 136  batch: 24  loss: 0.00278073
epoch: 136  batch: 25  loss: 0.00450077
epoch: 136  batch: 26  loss: 0.00380504
epoch: 136  batch: 27  loss: 0.00350356
epoch: 136  batch: 28  loss: 0.00198718
epoch: 137  batch: 1  loss: 0.00331358
epoch: 137  batch: 2  loss: 0.00178232
epoch: 137  batch: 3  loss: 0.00680232
epoch: 137  batch: 4  loss: 0.00169383
epoch: 137  batch: 5  loss: 0.00215490
epoch: 137  batch: 6  loss: 0.00220461
epoch: 137  batch: 7  loss: 0.02134942
epoch: 137  batch: 8  loss: 0.00193842
epoch: 1

epoch: 143  batch: 23  loss: 0.00284842
epoch: 143  batch: 24  loss: 0.00429032
epoch: 143  batch: 25  loss: 0.00276147
epoch: 143  batch: 26  loss: 0.00405249
epoch: 143  batch: 27  loss: 0.00775520
epoch: 143  batch: 28  loss: 0.00536492
epoch: 144  batch: 1  loss: 0.00156705
epoch: 144  batch: 2  loss: 0.00251963
epoch: 144  batch: 3  loss: 0.00062412
epoch: 144  batch: 4  loss: 0.00380820
epoch: 144  batch: 5  loss: 0.00281605
epoch: 144  batch: 6  loss: 0.00944375
epoch: 144  batch: 7  loss: 0.01798381
epoch: 144  batch: 8  loss: 0.00133262
epoch: 144  batch: 9  loss: 0.00185799
epoch: 144  batch: 10  loss: 0.00235114
epoch: 144  batch: 11  loss: 0.00606359
epoch: 144  batch: 12  loss: 0.01220179
epoch: 144  batch: 13  loss: 0.00114561
epoch: 144  batch: 14  loss: 0.00234336
epoch: 144  batch: 15  loss: 0.00536424
epoch: 144  batch: 16  loss: 0.00735407
epoch: 144  batch: 17  loss: 0.00806700
epoch: 144  batch: 18  loss: 0.00643826
epoch: 144  batch: 19  loss: 0.00209478
epoch: 14

epoch: 151  batch: 6  loss: 0.00374786
epoch: 151  batch: 7  loss: 0.00222573
epoch: 151  batch: 8  loss: 0.00229754
epoch: 151  batch: 9  loss: 0.00164711
epoch: 151  batch: 10  loss: 0.00127863
epoch: 151  batch: 11  loss: 0.00592621
epoch: 151  batch: 12  loss: 0.00214311
epoch: 151  batch: 13  loss: 0.00123432
epoch: 151  batch: 14  loss: 0.00929856
epoch: 151  batch: 15  loss: 0.00146576
epoch: 151  batch: 16  loss: 0.00166697
epoch: 151  batch: 17  loss: 0.00490740
epoch: 151  batch: 18  loss: 0.00509938
epoch: 151  batch: 19  loss: 0.00074131
epoch: 151  batch: 20  loss: 0.00245778
epoch: 151  batch: 21  loss: 0.00581743
epoch: 151  batch: 22  loss: 0.00347978
epoch: 151  batch: 23  loss: 0.00669208
epoch: 151  batch: 24  loss: 0.00555242
epoch: 151  batch: 25  loss: 0.01137585
epoch: 151  batch: 26  loss: 0.00432497
epoch: 151  batch: 27  loss: 0.01407775
epoch: 151  batch: 28  loss: 0.00575106
epoch: 152  batch: 1  loss: 0.00650188
epoch: 152  batch: 2  loss: 0.00901834
epoch:

epoch: 158  batch: 17  loss: 0.00207977
epoch: 158  batch: 18  loss: 0.00775211
epoch: 158  batch: 19  loss: 0.00621784
epoch: 158  batch: 20  loss: 0.01056100
epoch: 158  batch: 21  loss: 0.00244698
epoch: 158  batch: 22  loss: 0.00318478
epoch: 158  batch: 23  loss: 0.00392020
epoch: 158  batch: 24  loss: 0.01710017
epoch: 158  batch: 25  loss: 0.00366205
epoch: 158  batch: 26  loss: 0.00275046
epoch: 158  batch: 27  loss: 0.00465060
epoch: 158  batch: 28  loss: 0.00882904
epoch: 159  batch: 1  loss: 0.00567884
epoch: 159  batch: 2  loss: 0.00435251
epoch: 159  batch: 3  loss: 0.01993708
epoch: 159  batch: 4  loss: 0.00506807
epoch: 159  batch: 5  loss: 0.00387960
epoch: 159  batch: 6  loss: 0.00128910
epoch: 159  batch: 7  loss: 0.00094756
epoch: 159  batch: 8  loss: 0.00101556
epoch: 159  batch: 9  loss: 0.00405334
epoch: 159  batch: 10  loss: 0.00522464
epoch: 159  batch: 11  loss: 0.00189907
epoch: 159  batch: 12  loss: 0.00133621
epoch: 159  batch: 13  loss: 0.02112888
epoch: 15

epoch: 166  batch: 1  loss: 0.00103051
epoch: 166  batch: 2  loss: 0.00157034
epoch: 166  batch: 3  loss: 0.00309841
epoch: 166  batch: 4  loss: 0.00155655
epoch: 166  batch: 5  loss: 0.00296015
epoch: 166  batch: 6  loss: 0.00353720
epoch: 166  batch: 7  loss: 0.00115784
epoch: 166  batch: 8  loss: 0.00200631
epoch: 166  batch: 9  loss: 0.01832648
epoch: 166  batch: 10  loss: 0.00695322
epoch: 166  batch: 11  loss: 0.01454894
epoch: 166  batch: 12  loss: 0.00356486
epoch: 166  batch: 13  loss: 0.00222739
epoch: 166  batch: 14  loss: 0.00255991
epoch: 166  batch: 15  loss: 0.00212681
epoch: 166  batch: 16  loss: 0.00368794
epoch: 166  batch: 17  loss: 0.00601977
epoch: 166  batch: 18  loss: 0.00530006
epoch: 166  batch: 19  loss: 0.00280116
epoch: 166  batch: 20  loss: 0.00323224
epoch: 166  batch: 21  loss: 0.00265100
epoch: 166  batch: 22  loss: 0.00553639
epoch: 166  batch: 23  loss: 0.00300054
epoch: 166  batch: 24  loss: 0.00510846
epoch: 166  batch: 25  loss: 0.00333284
epoch: 16

epoch: 173  batch: 12  loss: 0.00465865
epoch: 173  batch: 13  loss: 0.00512879
epoch: 173  batch: 14  loss: 0.00299319
epoch: 173  batch: 15  loss: 0.00530923
epoch: 173  batch: 16  loss: 0.02180788
epoch: 173  batch: 17  loss: 0.00099596
epoch: 173  batch: 18  loss: 0.00127197
epoch: 173  batch: 19  loss: 0.00293399
epoch: 173  batch: 20  loss: 0.00437767
epoch: 173  batch: 21  loss: 0.00457370
epoch: 173  batch: 22  loss: 0.00146990
epoch: 173  batch: 23  loss: 0.00172328
epoch: 173  batch: 24  loss: 0.00737618
epoch: 173  batch: 25  loss: 0.00163112
epoch: 173  batch: 26  loss: 0.00214101
epoch: 173  batch: 27  loss: 0.00456694
epoch: 173  batch: 28  loss: 0.00625858
epoch: 174  batch: 1  loss: 0.00209795
epoch: 174  batch: 2  loss: 0.00209142
epoch: 174  batch: 3  loss: 0.00158171
epoch: 174  batch: 4  loss: 0.02248824
epoch: 174  batch: 5  loss: 0.00281043
epoch: 174  batch: 6  loss: 0.01280332
epoch: 174  batch: 7  loss: 0.00427098
epoch: 174  batch: 8  loss: 0.00247918
epoch: 1

epoch: 180  batch: 23  loss: 0.00191690
epoch: 180  batch: 24  loss: 0.00217396
epoch: 180  batch: 25  loss: 0.00523293
epoch: 180  batch: 26  loss: 0.01147476
epoch: 180  batch: 27  loss: 0.01890212
epoch: 180  batch: 28  loss: 0.00467308
epoch: 181  batch: 1  loss: 0.00257196
epoch: 181  batch: 2  loss: 0.00698272
epoch: 181  batch: 3  loss: 0.00225818
epoch: 181  batch: 4  loss: 0.00393159
epoch: 181  batch: 5  loss: 0.01828599
epoch: 181  batch: 6  loss: 0.00922867
epoch: 181  batch: 7  loss: 0.00959687
epoch: 181  batch: 8  loss: 0.00307638
epoch: 181  batch: 9  loss: 0.00619453
epoch: 181  batch: 10  loss: 0.00175985
epoch: 181  batch: 11  loss: 0.00609161
epoch: 181  batch: 12  loss: 0.00647321
epoch: 181  batch: 13  loss: 0.00854537
epoch: 181  batch: 14  loss: 0.00786505
epoch: 181  batch: 15  loss: 0.00269808
epoch: 181  batch: 16  loss: 0.00380868
epoch: 181  batch: 17  loss: 0.00360632
epoch: 181  batch: 18  loss: 0.00446095
epoch: 181  batch: 19  loss: 0.00192924
epoch: 18

epoch: 188  batch: 6  loss: 0.00345494
epoch: 188  batch: 7  loss: 0.00363656
epoch: 188  batch: 8  loss: 0.00398313
epoch: 188  batch: 9  loss: 0.00512556
epoch: 188  batch: 10  loss: 0.00429054
epoch: 188  batch: 11  loss: 0.01784040
epoch: 188  batch: 12  loss: 0.00470186
epoch: 188  batch: 13  loss: 0.00414327
epoch: 188  batch: 14  loss: 0.00480407
epoch: 188  batch: 15  loss: 0.01321133
epoch: 188  batch: 16  loss: 0.00620391
epoch: 188  batch: 17  loss: 0.00076493
epoch: 188  batch: 18  loss: 0.00274350
epoch: 188  batch: 19  loss: 0.00505291
epoch: 188  batch: 20  loss: 0.00496565
epoch: 188  batch: 21  loss: 0.00347638
epoch: 188  batch: 22  loss: 0.00873775
epoch: 188  batch: 23  loss: 0.00236414
epoch: 188  batch: 24  loss: 0.00280515
epoch: 188  batch: 25  loss: 0.00541081
epoch: 188  batch: 26  loss: 0.00335616
epoch: 188  batch: 27  loss: 0.01221576
epoch: 188  batch: 28  loss: 0.00783827
epoch: 189  batch: 1  loss: 0.01503680
epoch: 189  batch: 2  loss: 0.00114878
epoch:

epoch: 195  batch: 17  loss: 0.00183538
epoch: 195  batch: 18  loss: 0.00559759
epoch: 195  batch: 19  loss: 0.00481666
epoch: 195  batch: 20  loss: 0.00347295
epoch: 195  batch: 21  loss: 0.00818957
epoch: 195  batch: 22  loss: 0.00201218
epoch: 195  batch: 23  loss: 0.00219526
epoch: 195  batch: 24  loss: 0.00937866
epoch: 195  batch: 25  loss: 0.00776727
epoch: 195  batch: 26  loss: 0.00422685
epoch: 195  batch: 27  loss: 0.00570830
epoch: 195  batch: 28  loss: 0.01671702
epoch: 196  batch: 1  loss: 0.00639903
epoch: 196  batch: 2  loss: 0.00563642
epoch: 196  batch: 3  loss: 0.00584184
epoch: 196  batch: 4  loss: 0.00736135
epoch: 196  batch: 5  loss: 0.00214780
epoch: 196  batch: 6  loss: 0.01025454
epoch: 196  batch: 7  loss: 0.01580467
epoch: 196  batch: 8  loss: 0.00649423
epoch: 196  batch: 9  loss: 0.00450439
epoch: 196  batch: 10  loss: 0.00607103
epoch: 196  batch: 11  loss: 0.00211576
epoch: 196  batch: 12  loss: 0.00451549
epoch: 196  batch: 13  loss: 0.00224715
epoch: 19

epoch: 203  batch: 1  loss: 0.00474118
epoch: 203  batch: 2  loss: 0.02228550
epoch: 203  batch: 3  loss: 0.00264210
epoch: 203  batch: 4  loss: 0.00411468
epoch: 203  batch: 5  loss: 0.00190903
epoch: 203  batch: 6  loss: 0.02348201
epoch: 203  batch: 7  loss: 0.00360638
epoch: 203  batch: 8  loss: 0.00153996
epoch: 203  batch: 9  loss: 0.00329842
epoch: 203  batch: 10  loss: 0.00544254
epoch: 203  batch: 11  loss: 0.00476652
epoch: 203  batch: 12  loss: 0.00199273
epoch: 203  batch: 13  loss: 0.00232144
epoch: 203  batch: 14  loss: 0.00234711
epoch: 203  batch: 15  loss: 0.00487207
epoch: 203  batch: 16  loss: 0.00205338
epoch: 203  batch: 17  loss: 0.00409488
epoch: 203  batch: 18  loss: 0.00270345
epoch: 203  batch: 19  loss: 0.00547857
epoch: 203  batch: 20  loss: 0.00351307
epoch: 203  batch: 21  loss: 0.00093060
epoch: 203  batch: 22  loss: 0.00246333
epoch: 203  batch: 23  loss: 0.00572000
epoch: 203  batch: 24  loss: 0.01042640
epoch: 203  batch: 25  loss: 0.00876263
epoch: 20

epoch: 210  batch: 12  loss: 0.00266710
epoch: 210  batch: 13  loss: 0.00610555
epoch: 210  batch: 14  loss: 0.00405738
epoch: 210  batch: 15  loss: 0.00515872
epoch: 210  batch: 16  loss: 0.00857773
epoch: 210  batch: 17  loss: 0.00422919
epoch: 210  batch: 18  loss: 0.00237130
epoch: 210  batch: 19  loss: 0.02358142
epoch: 210  batch: 20  loss: 0.00416427
epoch: 210  batch: 21  loss: 0.00510885
epoch: 210  batch: 22  loss: 0.00292818
epoch: 210  batch: 23  loss: 0.00232503
epoch: 210  batch: 24  loss: 0.01154174
epoch: 210  batch: 25  loss: 0.02182749
epoch: 210  batch: 26  loss: 0.00216792
epoch: 210  batch: 27  loss: 0.00247769
epoch: 210  batch: 28  loss: 0.00225156
epoch: 211  batch: 1  loss: 0.00642724
epoch: 211  batch: 2  loss: 0.00533446
epoch: 211  batch: 3  loss: 0.00784175
epoch: 211  batch: 4  loss: 0.00194356
epoch: 211  batch: 5  loss: 0.00326909
epoch: 211  batch: 6  loss: 0.00481241
epoch: 211  batch: 7  loss: 0.00370709
epoch: 211  batch: 8  loss: 0.00496689
epoch: 2

epoch: 217  batch: 23  loss: 0.00138590
epoch: 217  batch: 24  loss: 0.00400470
epoch: 217  batch: 25  loss: 0.00082338
epoch: 217  batch: 26  loss: 0.00242047
epoch: 217  batch: 27  loss: 0.00621947
epoch: 217  batch: 28  loss: 0.00913760
epoch: 218  batch: 1  loss: 0.01364012
epoch: 218  batch: 2  loss: 0.00246996
epoch: 218  batch: 3  loss: 0.00555656
epoch: 218  batch: 4  loss: 0.00240500
epoch: 218  batch: 5  loss: 0.00589209
epoch: 218  batch: 6  loss: 0.00372585
epoch: 218  batch: 7  loss: 0.00521355
epoch: 218  batch: 8  loss: 0.00158989
epoch: 218  batch: 9  loss: 0.00363951
epoch: 218  batch: 10  loss: 0.00365461
epoch: 218  batch: 11  loss: 0.00257936
epoch: 218  batch: 12  loss: 0.00173797
epoch: 218  batch: 13  loss: 0.00311718
epoch: 218  batch: 14  loss: 0.00333543
epoch: 218  batch: 15  loss: 0.00542112
epoch: 218  batch: 16  loss: 0.00737910
epoch: 218  batch: 17  loss: 0.00332221
epoch: 218  batch: 18  loss: 0.00386110
epoch: 218  batch: 19  loss: 0.00173345
epoch: 21

epoch: 225  batch: 6  loss: 0.00468394
epoch: 225  batch: 7  loss: 0.02123985
epoch: 225  batch: 8  loss: 0.01476820
epoch: 225  batch: 9  loss: 0.00700381
epoch: 225  batch: 10  loss: 0.00455163
epoch: 225  batch: 11  loss: 0.00757772
epoch: 225  batch: 12  loss: 0.00569533
epoch: 225  batch: 13  loss: 0.00197219
epoch: 225  batch: 14  loss: 0.00316611
epoch: 225  batch: 15  loss: 0.00454642
epoch: 225  batch: 16  loss: 0.00437483
epoch: 225  batch: 17  loss: 0.00973735
epoch: 225  batch: 18  loss: 0.00739588
epoch: 225  batch: 19  loss: 0.00288890
epoch: 225  batch: 20  loss: 0.00969771
epoch: 225  batch: 21  loss: 0.00201379
epoch: 225  batch: 22  loss: 0.00234213
epoch: 225  batch: 23  loss: 0.00435830
epoch: 225  batch: 24  loss: 0.00125655
epoch: 225  batch: 25  loss: 0.00043669
epoch: 225  batch: 26  loss: 0.00442627
epoch: 225  batch: 27  loss: 0.00435637
epoch: 225  batch: 28  loss: 0.00388112
epoch: 226  batch: 1  loss: 0.00681695
epoch: 226  batch: 2  loss: 0.00926599
epoch:

epoch: 232  batch: 17  loss: 0.00273762
epoch: 232  batch: 18  loss: 0.01251213
epoch: 232  batch: 19  loss: 0.00159747
epoch: 232  batch: 20  loss: 0.00451995
epoch: 232  batch: 21  loss: 0.00402270
epoch: 232  batch: 22  loss: 0.00423519
epoch: 232  batch: 23  loss: 0.00243320
epoch: 232  batch: 24  loss: 0.00180809
epoch: 232  batch: 25  loss: 0.00287524
epoch: 232  batch: 26  loss: 0.01089550
epoch: 232  batch: 27  loss: 0.00390047
epoch: 232  batch: 28  loss: 0.00720427
epoch: 233  batch: 1  loss: 0.00362789
epoch: 233  batch: 2  loss: 0.00851104
epoch: 233  batch: 3  loss: 0.00361556
epoch: 233  batch: 4  loss: 0.00240242
epoch: 233  batch: 5  loss: 0.00460789
epoch: 233  batch: 6  loss: 0.00467869
epoch: 233  batch: 7  loss: 0.02385299
epoch: 233  batch: 8  loss: 0.00996107
epoch: 233  batch: 9  loss: 0.00090099
epoch: 233  batch: 10  loss: 0.00430845
epoch: 233  batch: 11  loss: 0.00267027
epoch: 233  batch: 12  loss: 0.00236450
epoch: 233  batch: 13  loss: 0.00385709
epoch: 23

epoch: 240  batch: 1  loss: 0.00155367
epoch: 240  batch: 2  loss: 0.01246096
epoch: 240  batch: 3  loss: 0.00448113
epoch: 240  batch: 4  loss: 0.00168005
epoch: 240  batch: 5  loss: 0.00371817
epoch: 240  batch: 6  loss: 0.00327788
epoch: 240  batch: 7  loss: 0.00697038
epoch: 240  batch: 8  loss: 0.00306899
epoch: 240  batch: 9  loss: 0.00861995
epoch: 240  batch: 10  loss: 0.00462668
epoch: 240  batch: 11  loss: 0.00379119
epoch: 240  batch: 12  loss: 0.02998002
epoch: 240  batch: 13  loss: 0.00394808
epoch: 240  batch: 14  loss: 0.00580151
epoch: 240  batch: 15  loss: 0.00456455
epoch: 240  batch: 16  loss: 0.00108965
epoch: 240  batch: 17  loss: 0.00750567
epoch: 240  batch: 18  loss: 0.00497876
epoch: 240  batch: 19  loss: 0.00953124
epoch: 240  batch: 20  loss: 0.00641900
epoch: 240  batch: 21  loss: 0.00816988
epoch: 240  batch: 22  loss: 0.00945110
epoch: 240  batch: 23  loss: 0.00332543
epoch: 240  batch: 24  loss: 0.00541315
epoch: 240  batch: 25  loss: 0.00557011
epoch: 24

epoch: 247  batch: 12  loss: 0.01356728
epoch: 247  batch: 13  loss: 0.02968577
epoch: 247  batch: 14  loss: 0.00126482
epoch: 247  batch: 15  loss: 0.00650626
epoch: 247  batch: 16  loss: 0.00304634
epoch: 247  batch: 17  loss: 0.01316819
epoch: 247  batch: 18  loss: 0.00413135
epoch: 247  batch: 19  loss: 0.01058752
epoch: 247  batch: 20  loss: 0.01179434
epoch: 247  batch: 21  loss: 0.00350121
epoch: 247  batch: 22  loss: 0.00791127
epoch: 247  batch: 23  loss: 0.00420343
epoch: 247  batch: 24  loss: 0.00116851
epoch: 247  batch: 25  loss: 0.00375016
epoch: 247  batch: 26  loss: 0.00206638
epoch: 247  batch: 27  loss: 0.00332132
epoch: 247  batch: 28  loss: 0.00939127
epoch: 248  batch: 1  loss: 0.00514365
epoch: 248  batch: 2  loss: 0.00188521
epoch: 248  batch: 3  loss: 0.00483537
epoch: 248  batch: 4  loss: 0.00095546
epoch: 248  batch: 5  loss: 0.00269765
epoch: 248  batch: 6  loss: 0.01010516
epoch: 248  batch: 7  loss: 0.00569417
epoch: 248  batch: 8  loss: 0.01076002
epoch: 2

In [23]:
# Save the model
torch.save(Model.state_dict(), 'AoR_MODEL6D_NEW2.pt')

In [20]:
plt.plot(train_losses, label='training loss')
plt.plot(validation_losses, label='validation loss')
plt.legend();
plt.grid()
plt.xlabel("Epoch")
plt.ylabel("Loss")

NameError: name 'train_losses' is not defined

# Cheking 

In [21]:
# Functions 
def CreatePointCloud(color_im, depth_im):
    color_raw = o3d.geometry.Image(color_im)
    depth_raw = o3d.geometry.Image(depth_im)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(color_raw, depth_raw, 1000) # 
    PointCloud = o3d.geometry.PointCloud.create_from_rgbd_image(
      rgbd_image,o3d.camera.PinholeCameraIntrinsic(o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault)) # Creates Point Cloud from rgbd image
    PointCloud.transform([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) # Flip it, otherwise the pointcloud will be upside down
    return PointCloud

def pick_points(pcd):
    vis = o3d.visualization.VisualizerWithEditing()
    vis.create_window()
    vis.add_geometry(pcd)
    vis.run()
    vis.destroy_window()
    numpy_array=np.asarray(pcd.points)
    point_id=vis.get_picked_points()

    return [numpy_array[point_id[0]],numpy_array[point_id[1]]]

def draw_arrow(pcd, points_real, points_extimated):
    lines=[[0,1],[2,3]]
    points = np.concatenate((points_real, points_extimated), axis=0)
    colors = [[1,0,0],[0,1,0]] # Red is REAL and Green is ESTIMATED
    line_set = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(points),
        lines=o3d.utility.Vector2iVector(lines),

    )
    line_set.colors=o3d.utility.Vector3dVector(colors)
    o3d.visualization.draw_geometries([pcd,line_set])

In [22]:
inv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 
                                     std=[1/0.229, 1/0.224, 1/0.225])

inv_resize = transforms.Resize(480)

In [24]:
Model = AoRNet()
Model.load_state_dict(torch.load('AoR_MODEL6D_NEW2.pt'))
Model.eval()

# torch.manual_seed(101)
with torch.no_grad():
    for b, (X_validation, y_validation) in enumerate(validation_loader):
#         Apply the model
        y_val = Model(X_validation)
#         print(y_val.shape)
        for j in range(y_val.shape[0]):
            X_invNorm = inv_resize(X_validation[j])
            RGB_buff = inv_normalize(torch.stack((X_invNorm[0],X_invNorm[1],X_invNorm[2]))).numpy()*255
#             RGB_buff = np.stack((X_invNorm[0].numpy(),X_invNorm[1].numpy(),X_invNorm[2].numpy()))*255
            RGB_buff = np.transpose(RGB_buff, (1,2,0))
            RGB_buff = np.ascontiguousarray(RGB_buff, dtype=np.uint8)

            DEPTH_buff = X_invNorm[3].numpy()*65535
            PC = CreatePointCloud(RGB_buff, DEPTH_buff)
            PREDICTED = [[y_val[j][0].cpu().numpy(), y_val[j][1].cpu().numpy(), y_val[j][2].cpu().numpy()],
                         [y_val[j][3].cpu().numpy(), y_val[j][4].cpu().numpy(), y_val[j][5].cpu().numpy()]]
            REAL = [[y_validation[j][0].cpu().numpy(), y_validation[j][1].cpu().numpy(), y_validation[j][2].cpu().numpy()],
                    [y_validation[j][3].cpu().numpy(), y_validation[j][4].cpu().numpy(), y_validation[j][5].cpu().numpy()]]
            draw_arrow(PC, REAL, PREDICTED)

            print(f'--> BATCH: {b+1} <-- | --> ROW: {j} <--')
            print(f'----------------------------------------------------------------------------------------------')
            print(f'{"X1":>12} {"Y1":>12} {"Z1":>12} {"X2":>12} {"Y2":>12} {"Z2":>12}')
            print(f'{"PREDICTED:"}')
            print(f'[[{y_val[j][0]:12.5f}, {y_val[j][1]:12.5f}, {y_val[j][2]:12.5f}], [{y_val[j][3]:12.5f}, {y_val[j][4]:12.5f}, {y_val[j][5]:12.5f}]]')
            print(f'{"REAL:"}')
            print(f'[[{y_validation[j][0]:12.5f}, {y_validation[j][1]:12.5f}, {y_validation[j][2]:12.5f}], [{y_validation[j][3]:12.5f}, {y_validation[j][4]:12.5f}, {y_validation[j][5]:12.5f}]]')
            print(f'{"DIFFERENCE:"}')
            diff = np.abs(y_val.cpu().numpy()-y_validation.cpu().numpy())
            print(f'[[{diff[j][0]:12.5f}, {diff[j][1]:12.5f}, {diff[j][2]:12.5f}], [{diff[j][3]:12.5f}, {diff[j][4]:12.5f}, {diff[j][5]:12.5f}]]')
            print(f'----------------------------------------------------------------------------------------------')
# loss = criterion(y_val, y_validation.cuda())
# diff = np.abs(y_val.cpu().numpy()-y_validation.cpu().numpy())
# print(f'RMSE: {loss:.8f}')


--> BATCH: 1 <-- | --> ROW: 0 <--
----------------------------------------------------------------------------------------------
          X1           Y1           Z1           X2           Y2           Z2
PREDICTED:
[[    -0.26318,      0.66734,     -1.58974], [    -0.23421,     -0.74804,     -1.64120]]
REAL:
[[    -0.10797,      0.54656,     -1.41700], [    -0.09271,     -0.63447,     -1.52100]]
DIFFERENCE:
[[     0.15521,      0.12079,      0.17274], [     0.14150,      0.11357,      0.12020]]
----------------------------------------------------------------------------------------------
--> BATCH: 2 <-- | --> ROW: 0 <--
----------------------------------------------------------------------------------------------
          X1           Y1           Z1           X2           Y2           Z2
PREDICTED:
[[     0.02949,      0.66142,     -3.11824], [     0.07687,     -1.10354,     -2.71482]]
REAL:
[[    -0.00797,      0.61028,     -2.77400], [     0.03546,     -1.03771,     -2.48200]]


--> BATCH: 15 <-- | --> ROW: 0 <--
----------------------------------------------------------------------------------------------
          X1           Y1           Z1           X2           Y2           Z2
PREDICTED:
[[    -0.26697,      0.77940,     -1.85090], [    -0.22453,     -0.79058,     -1.90878]]
REAL:
[[    -0.19092,      0.68764,     -1.77400], [    -0.13623,     -0.76245,     -1.83600]]
DIFFERENCE:
[[     0.07605,      0.09176,      0.07690], [     0.08829,      0.02813,      0.07278]]
----------------------------------------------------------------------------------------------
--> BATCH: 16 <-- | --> ROW: 0 <--
----------------------------------------------------------------------------------------------
          X1           Y1           Z1           X2           Y2           Z2
PREDICTED:
[[     0.66832,     -0.02169,     -1.26542], [     0.66035,     -0.26473,     -1.30080]]
REAL:
[[     0.44540,      0.13033,     -1.19000], [     0.44683,     -0.33906,     -1.24450]

In [None]:
Unet