<a href="https://colab.research.google.com/github/giyeongyoon/3rd_AGC/blob/master/two_stage_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

참고문헌

*   Estimation of Greenhouse Lettuce Growth Indices Based on a
Two-Stage CNN Using RGB-D Images
[Data Availability](https://data.4tu.nl/articles/dataset/3rd_Autonomous_Greenhouse_Challenge_Online_Challenge_Lettuce_Images/15023088/1)

In [1]:
%%capture
!pip install albumentations==1.1.0
!pip install agml

Import libraries

In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import os
import cv2
import albumentations as A
import pandas as pd
import matplotlib.pyplot as plt
import time

Download 2021 Autonomous Greenhouse Challenge dataset

In [3]:
import agml
loader = agml.data.AgMLDataLoader('autonomous_greenhouse_regression', dataset_path = './')

Downloading autonomous_greenhouse_regression (size = 887.2 MB): 887226368it [00:58, 15203822.09it/s]                               


[AgML Download]: Extracting files for autonomous_greenhouse_regression... Done!

You have just downloaded [1mautonomous_greenhouse_regression[0m.

This dataset is licensed under the [1mCC BY-SA 4.0[0m license.
To learn more, visit: https://creativecommons.org/licenses/by-sa/4.0/

When using this dataset, please cite the following:

@misc{https://doi.org/10.4121/15023088.v1,
  doi = {10.4121/15023088.V1},
  url = {https://data.4tu.nl/articles/_/15023088/1},
  author = {Hemming,  S. (Silke) and de Zwart,  H.F. (Feije) and Elings,  A. (Anne) and bijlaard,  monique and Marrewijk,  van,  Bart and Petropoulou,  Anna},
  keywords = {Horticultural Crops,  Mechanical Engineering,  FOS: Mechanical engineering,  Artificial Intelligence and Image Processing,  FOS: Computer and information sciences,  Horticultural Production,  FOS: Agriculture,  forestry and fisheries,  Autonomous Greenhouse Challenge,  autonomous greenhouse,  Artificial Intelligence,  image processing,  computer vision,  Horti

AttributeError: ignored

Define data and output directories

In [4]:
sav_dir='model_weights/'
if not os.path.exists(sav_dir):
    os.mkdir(sav_dir)
# Comment these two lines and uncomment the next two if you've already croppped the images to another directory
RGB_Data_Dir   = './autonomous_greenhouse_regression/images/'
Depth_Data_Dir = './autonomous_greenhouse_regression/depth_images/'


# RGB_Data_Dir='./autonomous_greenhouse_regression/cropped_images/'
# Depth_Data_Dir='./autonomous_greenhouse_regression/cropped_depth_images/'


JSON_Files_Dir = './autonomous_greenhouse_regression/annotations.json'

Crop

In [5]:
# import matplotlib.pyplot as plt
min_x=650
max_x=1450
min_y=200
max_y=900
cropped_img_dir='./autonomous_greenhouse_regression/cropped_images/'

cropped_depth_img_dir='./autonomous_greenhouse_regression/cropped_depth_images/'

if not os.path.exists(cropped_img_dir):
    os.mkdir(cropped_img_dir)

if not os.path.exists(cropped_depth_img_dir):
    os.mkdir(cropped_depth_img_dir)

for im in os.listdir(RGB_Data_Dir):
    img = cv2.imread(RGB_Data_Dir+im)
    crop_img = img[min_y:max_y,min_x:max_x]
    cv2.imwrite(cropped_img_dir+im, crop_img)

for depth_im in os.listdir(Depth_Data_Dir):
    depth_img = cv2.imread(Depth_Data_Dir+depth_im, 0)
    crop_depth_img = depth_img[min_y:max_y,min_x:max_x]
    cv2.imwrite(cropped_depth_img_dir+depth_im, crop_depth_img)

RGB_Data_Dir   = cropped_img_dir
Depth_Data_Dir = cropped_depth_img_dir

Check the targets

In [6]:
# df= pd.read_json(JSON_Files_Dir)
# row = df.iloc[0]
# print(list(row['outputs']['regression']))
# print(list(row['outputs']['regression'].values()))

Create PyTorch dataset, create PyTorch dataloader, and split train/val/test

In [7]:
split_seed = 12
num_epochs = 400

In [8]:
class GreenhouseDataset(Dataset):
    def __init__(self, rgb_dir, d_dir, jsonfile_dir, rgb_transforms=None, d_transforms=None):

        self.df= pd.read_json(jsonfile_dir)
        # flatten_json is a custom function to flat the nested json files!

        self.rgb_transforms = rgb_transforms
        self.d_transforms = d_transforms
        self.rgb_dir = rgb_dir
        self.d_dir = d_dir
        self.num_outputs = len(self.df.iloc[0]['outputs']['regression'])


    def __getitem__(self, idx):
        # load images
        row=self.df.iloc[idx]

        rgb = plt.imread(self.rgb_dir+row['image'])
        depth = plt.imread(self.d_dir+row['depth_image'])
        depth = np.expand_dims(depth, 2)

        target = list(row['outputs']['regression'].values())

        #make sure your img and mask array are in this format before passing into albumentations transforms, img.shape=[H, W, C]
        if self.rgb_transforms is not None:
            aug_rgb = self.rgb_transforms(image=rgb)
            rgb = aug_rgb['image']
        elif self.d_transforms is not None:
            aug_depth = self.d_transforms(image=depth)
            depth = aug_depth['image']

        rgb = np.transpose(rgb, (2,0,1))
        depth = np.transpose(depth, (2,0,1))

        #pytorch wants a different format for the image ([C, H, W])
        rgb = torch.as_tensor(rgb, dtype=torch.float32)
        depth = torch.as_tensor(depth, dtype=torch.float32)
        target=torch.as_tensor(target, dtype=torch.float32)

        return rgb, depth, target

    def __len__(self):
        return len(self.df)

In [9]:
## FIGURE OUT HOW TO CROP ALL THE IMAGES TO GET RID OF EXTRANIOUS PIXELS
def get_transforms(train, means, stds):
    if train:
        transforms = A.Compose([
        # A.Crop(x_min=650, y_min=200, x_max=1450, y_max=900, always_apply=False, p=1.0),
        A.Flip(p=0.5),
        A.ShiftScaleRotate(always_apply=False, p=0.5, shift_limit=(-0.06, 0.06), scale_limit=(-0.1, 0.1), rotate_limit=(-5, 5), interpolation=0, border_mode=0, value=means, mask_value=None),
        A.Normalize(mean=means, std=stds, max_pixel_value=1.0, always_apply=False, p=1.0)
        ])
    else:
        transforms =  A.Compose([
        # A.Crop(x_min=650, y_min=200, x_max=1450, y_max=900, always_apply=False, p=1.0),
        A.Normalize(mean=means, std=stds, max_pixel_value=1.0, always_apply=False, p=1.0)
        ])
    return transforms

In [10]:
# Instantiate the PyTorch datalaoder the autonomous greenhouse dataset.
dataset = GreenhouseDataset(rgb_dir = RGB_Data_Dir,
                            d_dir = Depth_Data_Dir,
                            jsonfile_dir = JSON_Files_Dir,
                            rgb_transforms = get_transforms(train=False, means=[0,0,0],stds=[1,1,1]),
                            d_transforms = get_transforms(train=False, means=[0,0,0],stds=[1,1,1]))

# Remove last 50 images from training/validation set. These are the test set.
dataset.df= dataset.df.iloc[:-50]

# Split train and validation set. Stratify based on variety.
train_split, val_split = train_test_split(dataset.df,
                                          test_size = 0.2,
                                          random_state = split_seed,
                                          stratify = dataset.df['outputs'].str['classification']) #change to None if you don't have class info
train = torch.utils.data.Subset(dataset, train_split.index.tolist())
val   = torch.utils.data.Subset(dataset, val_split.index.tolist())

# Create train and validation dataloaders
train_loader = torch.utils.data.DataLoader(train, batch_size=6, num_workers=6, shuffle=True)
val_loader   = torch.utils.data.DataLoader(val,   batch_size=6, shuffle=False, num_workers=6)




Determine the mean and standard deviation of images for normalization (Only need to do once for a new dataset)

In [11]:
# this part is just to check the MEAN and STD of the dataset (dont run unless you need mu and sigma)

n_rgb = 0
n_depth = 0
mean_rgb = 0.
std_rgb = 0.
mean_depth = 0.
std_depth = 0.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=12)
for rgb, depth, _ in dataloader:

    # Rearrange batch to be the shape of [B, C, W * H]
    rgb = rgb.view(rgb.size(0), rgb.size(1), -1)
    depth = depth.view(depth.size(0), depth.size(1), -1)
    # Update total number of images
    n_rgb += rgb.size(0)
    n_depth += depth.size(0)
    # Compute mean and std here
    mean_rgb += rgb.mean(2).sum(0)
    std_rgb += rgb.std(2).sum(0)
    mean_depth += depth.mean(2).sum(0)
    std_depth += depth.std(2).sum(0)

# Final step
mean_rgb /= n_rgb
std_rgb /= n_rgb
mean_depth /= n_depth
std_depth /= n_depth

print('Mean of RGB: '+ str(mean_rgb))
print('Standard Deviation of RGB', str(std_rgb))
print('Mean of Depth: '+ str(mean_depth))
print('Standard Deviation of Depth', str(std_depth))



Mean of RGB: tensor([0.5482, 0.4620, 0.3602])
Standard Deviation of RGB tensor([0.1639, 0.1761, 0.2659])
Mean of Depth: tensor([0.0127])
Standard Deviation of Depth tensor([0.0035])


Copy the output of the previous cells into here to avoid needing to redetermine mean and std every time

In [12]:
dataset.means = [0.5482, 0.4620, 0.3602, 0.0127]  #these values were copied from the previous cell
dataset.stds = [0.1639, 0.1761, 0.2659, 0.0035]   #copy and paste the values to avoid having
                                                  # to rerun the previous cell for every iteration

Set device

In [13]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available else 'cpu')

Model

In [14]:
# !pip install torchsummary
# import torchsummary

# model = models.resnet50(pretrained=True)
# model = model.cuda()
# torchsummary.summary(model, (3, 224, 224))

In [35]:
class FirstStageModel(nn.Module):
    def __init__(self):
        super(FirstStageModel, self).__init__()
        # RGB Model
        self.rgb_processing_block = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
                                                  nn.Conv2d(32, 3, kernel_size=1),
                                                  nn.AdaptiveAvgPool2d((224, 224)))
        self.rgb_encoder = models.resnet50(pretrained=True)
        self.rgb_regressor = nn.Sequential(nn.ReLU(),
                                           nn.Dropout(0.5),
                                           nn.Linear(1000, 256),
                                           nn.ReLU(),
                                           nn.Dropout(0.5),
                                           nn.Linear(256, 3))


        # Depth Model
        self.depth_processing_block = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
                                         nn.Conv2d(32, 1, kernel_size=1),
                                         nn.AdaptiveAvgPool2d((224, 224)))
        self.depth_encoder = models.resnet50(pretrained=False)
        self.depth_encoder.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.depth_regressor = nn.Sequential(nn.ReLU(),
                                             nn.Dropout(0.5),
                                             nn.Linear(1000, 256),
                                             nn.ReLU(),
                                             nn.Dropout(0.5),
                                             nn.Linear(256, 1))

        self.final = nn.Sequential(nn.Dropout(0.5),
                                   nn.Linear(4, 2048),
                                   nn.ReLU(),
                                   nn.Dropout(0.5),
                                   nn.Linear(2048, 3),
                                   nn.Dropout(0.5))

    def forward(self, rgb, depth):
        rgb_out = self.rgb_processing_block(rgb)
        rgb_out = self.rgb_encoder(rgb_out)
        rgb_out = self.rgb_regressor(rgb_out)

        depth_out = self.depth_processing_block(depth)
        depth_out = self.depth_encoder(depth_out)
        output2 = self.depth_regressor(depth_out)  # height

        output1 = torch.cat([rgb_out, output2], dim=1)
        output1 = self.final(output1)  # fresh weight, dry weight, diameter

        return output1, output2

In [36]:
class SecondStageModel(nn.Module):
    def __init__(self):
        super(SecondStageModel, self).__init__()
        self.regressor1 = nn.Sequential(nn.Linear(4, 2048),
                                        nn.ReLU(),
                                        nn.Dropout(0.5),
                                        nn.Linear(2048, 2048),
                                        nn.ReLU(),
                                        nn.Dropout(0.5),
                                        nn.Linear(2048, 1))  # dry weight
        self.regressor2 = nn.Sequential(nn.Linear(4, 2048),
                                        nn.ReLU(),
                                        nn.Dropout(0.5),
                                        nn.Linear(2048, 2048),
                                        nn.ReLU(),
                                        nn.Dropout(0.5),
                                        nn.Linear(2048, 1))  # leaf area

    def forward(self, output1, output2):
        x = torch.cat([output1, output2], dim=1)
        output3 = self.regressor1(x)
        output4 = self.regressor2(x)
        return output3, output4

In [37]:
first_stage_model = FirstStageModel()
second_stage_model = SecondStageModel()



Hyperparameter

In [38]:
lr = 0.001
epochs = 200
batch_size = 32

Loss and optimizer

In [39]:
criterion_stage1 = nn.MSELoss()
criterion_stage2 = nn.MSELoss()

optimizer_stage1 = optim.Adam(first_stage_model.parameters(), lr=lr)
optimizer_stage2 = optim.Adam(second_stage_model.parameters(), lr=lr)

NMSE Loss

In [69]:
class NMSELoss(nn.Module):
    def __init__(self):
          # super(diceloss, self).init()
        super(NMSELoss, self).__init__()
          # print('HI')
    def forward(self, pred, target):
        if target.size() != pred.size():
              raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), pred.size()))

        num=torch.sum((target-pred)**2,0)
        den=torch.sum(target**2,0)

        return torch.sum(num/den)

Train

In [40]:
def train_single_epoch(first_stage_model, second_stage_model, dataset, device,
                       criterion_stage1, criterion_stage2, optimizer_stage1, optimizer_stage2,
                       writer, epoch, train_loader):
    first_stage_model.train()
    second_stage_model.train()

    dataset.rgb_transforms = get_transforms(train=True, means=dataset.means[:3], stds=dataset.stds[:3])
    dataset.d_transforms = get_transforms(train=True, means=dataset.means[3:], stds=dataset.stds[3:])

    for i, (rgb, depth, label) in enumerate(train_loader):
        rgb = rgb.to(device)
        depth = depth.to(device)
        label = label.to(device)  # ['FreshWeightShoot', 'DryWeightShoot', 'Height', 'Diameter', 'LeafArea']

        # Forward pass - First stage
        pred1, pred2 = first_stage_model(rgb, depth)  # pred1: fresh weight, dry weight, diameter
                                                      # pred2: height

        # Forward pass - Second stage
        pred3, pred4 = second_stage_model(pred1, pred2)  # pred3: dry weight
                                                         # pred4: leaf area

        # Calculate loss
        loss_stage1 = criterion_stage1(pred1[:, [0,2]], label[:, [0,2]]) + criterion_stage1(pred2, label[:, 3])
        loss_stage2 = criterion_stage2(pred3, label[:, 1]) + criterion_stage2(pred4, label[:, 4])
        total_loss = loss_stage1 + loss_stage2

        # Backward pass and optimization
        optimizer_stage1.zero_grad()
        optimizer_stage2.zero_grad()
        total_loss.backward()
        optimizer_stage1.step()
        optimizer_stage2.step()

        print(f'Epoch {epoch+1}/{epochs}, Batch {i+1}/{len(train_loader)}, Loss1: {loss_stage1.item()}, Loss2: {loss_stage2.item()}, Total_Loss: {total_loss.item()}')
        with open('run.txt', 'a') as f:
            f.write('\n')
            f.write('Train MSE: '+ str(total_loss.tolist()))


In [41]:
def validate(first_stage_model, second_stage_model, dataset, device, sav_dir,
             criterion_stage1, criterion_stage2, writer, epoch, val_loader, best_val_loss):
    current_val_loss = 0
    # training_val_loss=0s

    first_stage_model.eval()
    second_stage_model.eval()
    print('Validating and Checkpointing!')

    dataset.rgb_transforms = get_transforms(train=True, means=dataset.means[:3], stds=dataset.stds[:3])
    dataset.d_transforms = get_transforms(train=True, means=dataset.means[3:], stds=dataset.stds[3:])

    with torch.no_grad():
        for i, (rgb, depth, label) in enumerate(val_loader):
            rgb = rgb.to(device)
            depth = depth.to(device)
            label = label.to(device)

            pred1, pred2 = first_stage_model(rgb, depth)
            pred3, pred4 = second_stage_model(pred1, pred2)

            loss_stage1 = criterion_stage1(pred1[:, [0,2]], label[:, [0,2]]) + criterion_stage1(pred2, label[:, 3])
            loss_stage2 = criterion_stage2(pred3, label[:, 1]) + criterion_stage2(pred4, label[:, 4])
            total_loss = loss_stage1 + loss_stage2
            # acc=nmse(preds.detach(), targets)
            current_val_loss = current_val_loss + total_loss.item()
            # training_val_loss=training_val_loss+loss.detach().cpu().numpy()

        # writer.add_scalar("MSE Loss/val", training_val_loss, epoch)
        writer.add_scalar("MSE Loss/val", current_val_loss, epoch)

    if current_val_loss < best_val_loss or epoch == 0:
        best_val_loss = current_val_loss
        torch.save(first_stage_model.state_dict(), sav_dir+'bestmodel1' + '.pth')
        torch.save(second_stage_model.state_dict(), sav_dir+'bestmodel2' + '.pth')
        print('Best model Saved! Val MSE: ', str(best_val_loss))
        with open('run.txt', 'a') as f:
            f.write('\n')
            f.write('Best model Saved! Val MSE: '+ str(best_val_loss))

    else:
        print('Model is not good (might be overfitting)! Current val MSE: ', str(current_val_loss), 'Best Val MSE: ', str(best_val_loss))
        with open('run.txt', 'a') as f:
            f.write('\n')
            f.write('Model is not good (might be overfitting)! Current val MSE: '+ str(current_val_loss)+ 'Best Val MSE: '+ str(best_val_loss))
    return best_val_loss

In [42]:
first_stage_model.to(device)
second_stage_model.to(device)

best_val_loss = 9999999 # initial dummy value
current_val_loss = 0

writer = SummaryWriter()
start = time.time()

for epoch in range(epochs):
    with open('run.txt', 'a') as f:
                f.write('\n')
                f.write('Epoch: '+ str(epoch + 1) + ', Time Elapsed: '+ str((time.time()-start)/60) + ' mins')
    print('Epoch: ', str(epoch + 1), ', Time Elapsed: ', str((time.time()-start)/60), ' mins')
    train_single_epoch(first_stage_model, second_stage_model, dataset, device,
                        criterion_stage1, criterion_stage2, optimizer_stage1, optimizer_stage2,
                        writer, epoch, train_loader)
    best_val_loss = validate(first_stage_model, second_stage_model, dataset, device, sav_dir,
                                criterion_stage1, criterion_stage2, writer, epoch, val_loader, best_val_loss)

Epoch:  1 , Time Elapsed:  5.9803326924641924e-06  mins


  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/200, Batch 1/45, Loss1: 12832.4775390625, Loss2: 5148742.5, Total_Loss: 5161575.0
Epoch 1/200, Batch 2/45, Loss1: 16307.259765625, Loss2: 6324108.5, Total_Loss: 6340416.0
Epoch 1/200, Batch 3/45, Loss1: 15261.4951171875, Loss2: 9695828.0, Total_Loss: 9711089.0
Epoch 1/200, Batch 4/45, Loss1: 16796.138671875, Loss2: 7872105.5, Total_Loss: 7888901.5
Epoch 1/200, Batch 5/45, Loss1: 23518.794921875, Loss2: 4470282.0, Total_Loss: 4493801.0
Epoch 1/200, Batch 6/45, Loss1: 9797.783203125, Loss2: 1576259.0, Total_Loss: 1586056.75
Epoch 1/200, Batch 7/45, Loss1: 15525.228515625, Loss2: 4783346.0, Total_Loss: 4798871.0
Epoch 1/200, Batch 8/45, Loss1: 20651.53125, Loss2: 4192548.0, Total_Loss: 4213199.5
Epoch 1/200, Batch 9/45, Loss1: 13302.3388671875, Loss2: 3445683.5, Total_Loss: 3458985.75
Epoch 1/200, Batch 10/45, Loss1: 17918.37109375, Loss2: 3037508.5, Total_Loss: 3055426.75
Epoch 1/200, Batch 11/45, Loss1: 5250.34228515625, Loss2: 4350316.5, Total_Loss: 4355567.0
Epoch 1/200, Batch

  return F.mse_loss(input, target, reduction=self.reduction)


[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
Epoch 96/200, Batch 40/45, Loss1: 12897.4775390625, Loss2: 1105860.25, Total_Loss: 1118757.75
Epoch 96/200, Batch 41/45, Loss1: 11387.0498046875, Loss2: 4215505.5, Total_Loss: 4226892.5
Epoch 96/200, Batch 42/45, Loss1: 3098.924072265625, Loss2: 1268118.25, Total_Loss: 1271217.125
Epoch 96/200, Batch 43/45, Loss1: 8371.5576171875, Loss2: 1425316.625, Total_Loss: 1433688.125
Epoch 96/200, Batch 44/45, Loss1: 8874.2685546875, Loss2: 706419.8125, Total_Loss: 715294.0625
Epoch 96/200, Batch 45/45, Loss1: 6946.01904296875, Loss2: 1392568.125, Total_Loss: 1399514.125
Validating and Checkpointing!
Model is not good (might be overfitting)! Current val MSE:  27357832.25 Best Val MSE:  23773841.5
Epoch:  97 , Time Elapsed:  32.5388007124265  mins
Epoch 97/200, Batch 1/45, Loss1: 25200.40234375, Loss2: 7221777.5, Total_Loss: 7246978.0
Epoch 97/200, Batch 2/45, Loss1: 10695.734375, Loss2: 3530443.25, Total_Loss: 3541139.0
Epoch 97/200, Batch 3/45, 

Googl colab mount

In [44]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import shutil

drive_dir = '/content/drive/MyDrive/Colab Notebooks/growth_monitoring/Estimation of Greenhouse Lettuce Growth Indices Based on a Two-Stage CNN Using RGB-D Images'
drive_dir += '/model_weights'
if not os.path.exists(drive_dir):
    os.makedirs(drive_dir)

for file in os.listdir(sav_dir):
    shutil.copy(sav_dir + '/' + file, drive_dir)

Evaluation

Define the test dataset

In [66]:
# Instantiate the PyTorch datalaoder the autonomous greenhouse dataset.
testset = GreenhouseDataset(rgb_dir = RGB_Data_Dir,
                            d_dir = Depth_Data_Dir,
                            jsonfile_dir = JSON_Files_Dir,
                            rgb_transforms = get_transforms(train=False, means=dataset.means[:3], stds=dataset.stds[:3]),
                            d_transforms = get_transforms(train=False, means=dataset.means[3:], stds=dataset.stds[3:]))

# Grab last 50 images as test dataset
testset.df = testset.df[-50:]

# Get testset_size
testset_size = testset.df.shape[0]

# Create test dataloader
test_loader = torch.utils.data.DataLoader(testset,
                                          batch_size = 50,
                                          num_workers = 0,
                                          shuffle = False)

Define loss functions for model evaluation

In [70]:
cri = NMSELoss()
mse = nn.MSELoss()

Run the evaluation Loop

In [75]:
# Evaluation loop
device=torch.device('cuda')

with torch.no_grad():


    device=torch.device('cuda')
    model1 = FirstStageModel()
    model2 = SecondStageModel()
    model1.to(device)
    model2.to(device)
    model1.load_state_dict(torch.load(sav_dir + 'bestmodel1.pth'))
    model2.load_state_dict(torch.load(sav_dir + 'bestmodel2.pth'))
    model1.eval()
    model2.eval()

    ap=torch.zeros((0,5))
    at=torch.zeros((0,5))

    for rgb, depth, targets in test_loader:
        rgb = rgb.to(device)
        depth = depth.to(device)
        targets = targets.to(device)
        pred1, pred2 = model1(rgb, depth)
        pred3, pred4 = model2(pred1, pred2)
        pred = torch.cat([pred1[:, [0]], pred3], dim=1)  #fresh weight, dry weight
        pred = torch.cat([pred, pred2], dim=1)  # fresh weight, dry weight, height
        pred = torch.cat([pred, pred1[:, [2]]], dim=1)  # fresh weight, dry weight, height, diameter
        pred = torch.cat([pred, pred4], dim=1)  # fresh weight, dry weight, height, diameter, leaf area
        # mse_loss=mse(preds, targets)
        # nmse=criterion(preds, targets)
        # nmse, pred=cri(preds, targets)
        ap=torch.cat((ap, pred.detach().cpu()), 0)
        at=torch.cat((at, targets.detach().cpu()), 0)


    print('FW MSE: ', str(mse(ap[:,0],at[:,0]).tolist()))
    print('DW MSE: ', str(mse(ap[:,1],at[:,1]).tolist()))
    print('H MSE: ', str(mse(ap[:,2],at[:,2]).tolist()))
    print('D MSE: ', str(mse(ap[:,3],at[:,3]).tolist()))
    print('LA MSE: ', str(mse(ap[:,4],at[:,4]).tolist()))
    print('Overall NMSE: ', str(cri(ap,at).tolist()))


FW MSE:  29725.91796875
DW MSE:  31.338163375854492
H MSE:  426.22515869140625
D MSE:  717.1244506835938
LA MSE:  2489424.25
Overall NMSE:  5.332054138183594
