In [None]:
import torch
import os
from torch.utils.data import Dataset,DataLoader,random_split
import torch.nn as nn
from transformers import SwinForImageClassification, Trainer, TrainingArguments
from transformers import Trainer, TrainingArguments
from transformers import AutoFeatureExtractor
from PIL import Image
import numpy as np
import torchvision
import matplotlib.pyplot as plt
from torchvision import models

# Dataset Class

In [None]:
class PlantDataset(Dataset):
    """Sliding-window dataset over per-class image sequences.

    ``imgs_path`` is expected to contain one sub-directory per class. The files
    of each class are sorted by path and cut into overlapping windows of
    ``n_x + 1`` consecutive files: the first ``n_x`` are the model input and the
    last one is the prediction target.

    Classes with fewer than ``n_x + 1`` files are padded by repeating frames in
    order (every frame ``(n_x + 1) // len`` times, then the last frame for the
    remainder) so that exactly one window can be built from them.

    Args:
        imgs_path: Root directory holding one folder per class.
        n_x: Number of input frames per sample.
        feature_extractor: Callable applied to PIL images; it must accept
            ``return_tensors='pt'`` and return an object exposing
            ``pixel_values``.
    """

    def __init__(self,imgs_path,n_x,feature_extractor):
        super().__init__()
        self.n_x = n_x
        self.feature_extractor = feature_extractor
        self.imgs_list = []
        window = n_x + 1
        # Sorted so that sample indices do not depend on filesystem order.
        for class_name in sorted(os.listdir(imgs_path)):
            class_path = os.path.join(imgs_path,class_name)
            if not os.path.isdir(class_path):
                # Stray files next to the class folders are not sequences.
                continue
            # NOTE(review): paths are sorted lexicographically, so "img10"
            # comes before "img2" unless file names are zero-padded - confirm.
            imgs_name = sorted(
                os.path.join(class_path,img_name) for img_name in os.listdir(class_path)
            )
            if not imgs_name:
                # An empty class cannot be padded (would divide by zero).
                continue
            frames = self._pad_sequence(imgs_name,window)
            for start in range(len(frames) - n_x):
                self.imgs_list.append(frames[start:start + window])

    @staticmethod
    def _pad_sequence(imgs_name,window):
        """Return ``imgs_name`` padded in order to at least ``window`` items."""
        if len(imgs_name) >= window:
            return list(imgs_name)
        repeats, remainder = divmod(window,len(imgs_name))
        padded = [path for path in imgs_name for _ in range(repeats)]
        # The remainder is filled with the last frame at the END of the
        # sequence so the temporal order stays monotonic.
        padded.extend([imgs_name[-1]] * remainder)
        return padded

    @staticmethod
    def _load_rgb(img_path):
        """Open an image, convert it to RGB and release the file handle."""
        with Image.open(img_path) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self,idx):
        """Return ``(img_X, img_Y)``: the stacked input frames and the target frame."""
        *input_paths, target_path = self.imgs_list[idx]
        input_frames = [self._load_rgb(img_path) for img_path in input_paths]
        # `return_tensors` (plural) is the keyword that requests torch tensors.
        img_X = self.feature_extractor(input_frames,return_tensors = 'pt').pixel_values
        img_Y = self.feature_extractor(self._load_rgb(target_path),return_tensors = 'pt').pixel_values[0]
        return img_X,img_Y


# Plants-Grow Network

In [None]:
class Decoder(nn.Module):
    """Stack of five transposed convolutions that upsamples a 16-channel map to 3 channels.

    Each layer has stride 2, so with the kernel/padding choices below an
    8x8 input grows 8 -> 14 -> 28 -> 56 -> 112 -> 224 pixels per side.
    Every layer (including the last) is followed by ReLU and the result is
    passed through tanh, so outputs lie in [0, 1].
    NOTE(review): if the training targets are mean/std-normalised they can
    fall outside that range - confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        # Layer names and creation order are kept stable for checkpoints.
        self.conv1 = nn.ConvTranspose2d(in_channels=16, out_channels=32, kernel_size=2, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Upsample ``x`` through all five layers and squash the result with tanh."""
        hidden = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            hidden = self.relu(layer(hidden))
        return self.tanh(hidden)

class Identity(nn.Module):
    """Parameter-free pass-through layer; used to replace a classifier head."""

    def forward(self,x):
        # Hand the input back untouched.
        return x

    
class MyModel(nn.Module):
    """Next-frame predictor: frozen Swin features -> LSTM -> MLP -> transposed-conv decoder.

    Args:
        in_feature: Width of the per-frame feature vector produced by the Swin
            backbone; it is used as the LSTM input/hidden size and as the
            input width of ``map_model``.
        pretrain_swin: Name or path passed to
            ``SwinForImageClassification.from_pretrained``.
    """

    def __init__(self,in_feature,pretrain_swin):
        super().__init__()
        self.swin_model = SwinForImageClassification.from_pretrained(pretrain_swin)
        # Drop the classification head so `.logits` carries the pooled features.
        self.swin_model.classifier = Identity()
        self.swin_model.eval()
        self.in_feature = in_feature
        self.lstm_model = nn.LSTM(in_feature,in_feature,batch_first = True)
        self.deconv2D = Decoder()
        self.map_model = nn.Sequential(
            # Input width follows the LSTM hidden size instead of a literal 1024.
            nn.Linear(in_feature,256),
            nn.ReLU(),
            # 1024 = 16 * 8 * 8, the feature-map layout the decoder is fed with.
            nn.Linear(256,1024)
        )

    def train(self,mode = True):
        """Switch train/eval mode while keeping the frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        undo the ``eval()`` call made in ``__init__`` and re-enable any
        stochastic layers of the backbone during feature extraction.
        """
        super().train(mode)
        self.swin_model.eval()
        return self

    def forward(self,x):
        """
        x : [B,N,C,224,224]
        Returns the decoder output for the frame that follows the N inputs.
        """
        b,n,c,h,w = x.shape
        x = x.reshape(-1,c,h,w)
        # The backbone is a fixed feature extractor: no gradients are tracked.
        with torch.no_grad():
            x_ = self.swin_model(x).logits # B*N,in_feature

        x_ = x_.reshape(b,n,-1)

        x_lstm,_ = self.lstm_model(x_) # B,N,in_feature
        x_lstm = x_lstm[:,-1] # B,in_feature (state after the last frame)
        x_encoded = self.map_model(x_lstm).reshape(b,16,8,8)
        output = self.deconv2D(x_encoded) # B,3,224,224
        return output

# Siamese Network

In [None]:
class Simase_network(nn.Module):
    """Pair classifier built on a frozen Swin backbone and a frozen frame predictor.

    For every sample the frozen ``AE_model`` predicts the next frame (the
    "anchor"). Swin features of the anchor are concatenated with the features
    of the true next frame (label 0) and with the features of each of the N
    input frames (label 1); a small MLP scores every pair with a sigmoid.

    Args:
        in_features: Width of the Swin feature vector.
        pretrained_swin: Name or path passed to
            ``SwinForImageClassification.from_pretrained``.
        pretrained_AE: Path of a ``MyModel`` state dict.
    """

    def __init__(self,in_features,pretrained_swin,pretrained_AE):
        super().__init__()
        self.in_features = in_features
        self.swin_model = SwinForImageClassification.from_pretrained(pretrained_swin)
        self.swin_model.classifier = Identity()
        self.swin_model.eval()
        # The predictor consumes the same backbone features, so it shares the width.
        self.AE_model = MyModel(in_features,pretrained_swin)
        # map_location lets a checkpoint saved on GPU be restored on any machine;
        # the weights follow the module when it is moved afterwards.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        self.AE_model.load_state_dict(torch.load(pretrained_AE,map_location = 'cpu'))
        self.AE_model.eval()
        # Attribute name (typo included) is kept: it is part of the state-dict keys.
        self.simase_netword = nn.Sequential(
            nn.Linear(self.in_features * 2,64),
            nn.ReLU(),
            nn.Linear(64,1),
            nn.Sigmoid()
        )

    def train(self,mode = True):
        """Switch train/eval mode while keeping both frozen sub-models in eval mode.

        ``nn.Module.train`` recurses into every child and would otherwise undo
        the ``eval()`` calls made in ``__init__``.
        """
        super().train(mode)
        self.swin_model.eval()
        self.AE_model.eval()
        return self

    def forward(self,x,x_positive):
        """Score (anchor, candidate) pairs.

        Args:
            x: Input frames, shape [B,N,C,H,W].
            x_positive: True next frame, shape [B,C,H,W].

        Returns:
            ``(scores, labels)`` where ``scores`` has shape [B + B*N, 1] and
            ``labels`` is an int64 tensor on ``x``'s device holding 0 for the B
            positive pairs followed by 1 for the B*N negative pairs.
        """
        b,n,c,h,w = x.shape

        with torch.no_grad():
            output = self.AE_model(x) # B,C,H,W

        x = x.reshape(-1,c,h,w) # B*N,C,H,W

        # One backbone pass over inputs, positives and predictions together.
        x = torch.cat([x,x_positive,output],dim = 0)
        with torch.no_grad():
            x_ = self.swin_model(x).logits # B*N + B + B,in_features

        anchor = x_[-b:] # B,in_features (predicted frame)
        x_positive = x_[-2*b:-b] # B,in_features (true next frame)
        x_negative = x_[:-2*b].reshape(b,-1,self.in_features) # B,N,in_features
        gap_positive = torch.cat([anchor,x_positive],dim = 1)
        gap_negative = torch.cat([anchor.unsqueeze(1).repeat(1,n,1),x_negative],dim = -1).reshape(-1,self.in_features * 2)

        # Labels are created on the input's device so callers need no transfer.
        labels = torch.cat([
            torch.zeros(len(gap_positive),dtype = torch.long,device = x.device),
            torch.ones(len(gap_negative),dtype = torch.long,device = x.device),
        ])
        gap = torch.cat([gap_positive,gap_negative],dim = 0)
        output_simase = self.simase_netword(gap)
        return output_simase,labels
        
        

# DataLoader

In [None]:
# Number of input frames per sample; the frame that follows them is the target.
n_x = 3
model_name = 'microsoft/swin-base-patch4-window7-224'
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
data = PlantDataset("/kaggle/input/plants-data/Plants-Grow",n_x,feature_extractor)

# 80/20 split, seeded so that "Restart & Run All" reproduces the same partition.
# NOTE(review): neighbouring windows share frames, so a random split still lets
# held-out windows overlap with training windows - confirm this is acceptable.
SPLIT_SEED = 42
train_size = int(0.8 * len(data))
test_size = len(data) - train_size
train_dataset, test_dataset = random_split(
    data,
    [train_size, test_size],
    generator = torch.Generator().manual_seed(SPLIT_SEED),
)

# Full dataset (used by the training cells below), training split, held-out split.
data_loader = DataLoader(data,batch_size = 2,shuffle = True)
train_loader = DataLoader(train_dataset,batch_size = 2,shuffle = True)
# Validation must read the held-out split, not the full dataset; no shuffling
# is needed for evaluation.
val_loader = DataLoader(test_dataset,batch_size = 2,shuffle = False)

# Train Siamese

In [None]:
def train_epoch_Simase(model,epoch,train_loader,criterion,optimizer):
    """Run one training epoch of the pair classifier.

    Args:
        model: Module called as ``model(X, y)`` returning ``(scores, labels)``;
            label 0 marks positive pairs and label 1 negative pairs.
        epoch: Epoch number, used for logging only.
        train_loader: Iterable with ``len()`` yielding ``(X, y)`` batches.
        criterion: Loss called as ``criterion(scores, labels)``.
        optimizer: Optimizer stepping the model parameters.

    Returns:
        ``(mean_loss, accuracy, positive_accuracy, negative_accuracy)``.
        Ratios whose denominator is zero (empty loader, class absent) are
        reported as 0.0.
    """
    model.train()
    # Follow the model's device instead of assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    losses = 0.0
    true_pred = 0
    total_true_positive = 0
    total_true_negative = 0
    total_positive = 0
    total_negative = 0
    total = 0
    for idx,data in enumerate(train_loader):
        X = data[0].to(device)
        y = data[1].to(device)
        y_pred,labels = model(X,y)
        y_pred = y_pred.reshape(-1)  # always 1-D, even for a single score
        labels = labels.float().to(device)

        loss = criterion(y_pred,labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()

        # Metrics use detached predictions thresholded at 0.5.
        y_hat = torch.round(y_pred.detach())
        positive_mask = labels == 0
        negative_mask = labels == 1

        true_pred += (labels == y_hat).sum().item()
        total += len(y_hat)
        total_true_positive += (y_hat[positive_mask] == 0).sum().item()
        total_true_negative += (y_hat[negative_mask] == 1).sum().item()
        total_positive += int(positive_mask.sum().item())
        total_negative += int(negative_mask.sum().item())

        print(f"Epoch:{epoch} Batch:{idx} Loss:{loss.item()}")

    n_batches = len(train_loader)
    mean_loss = losses / n_batches if n_batches else 0.0
    acc = true_pred / total if total else 0.0
    acc_pos = total_true_positive / total_positive if total_positive else 0.0
    acc_neg = total_true_negative / total_negative if total_negative else 0.0
    print(f"Done Epoch {epoch}, Loss: {mean_loss} Acc : {acc * 100 :.2f} Acc_pos :{acc_pos} Acc_neg :{acc_neg}")
    return mean_loss, acc, acc_pos, acc_neg

def val_epoch_Simase(model,epoch,val_loader,criterion):
    """Evaluate the pair classifier for one pass over ``val_loader``.

    Args:
        model: Module called as ``model(X, y)`` returning ``(scores, labels)``;
            label 0 marks positive pairs and label 1 negative pairs.
        epoch: Epoch number, used for logging only.
        val_loader: Iterable with ``len()`` yielding ``(X, y)`` batches.
        criterion: Loss called as ``criterion(scores, labels)``.

    Returns:
        ``(mean_loss, accuracy)``. Ratios whose denominator is zero (empty
        loader, class absent) are reported as 0.0.
    """
    model.eval()
    # Follow the model's device instead of assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    losses = 0.0
    true_pred = 0
    total_true_positive = 0
    total_true_negative = 0
    total_positive = 0
    total_negative = 0
    total = 0
    # Nothing in evaluation needs gradients, including the loss.
    with torch.no_grad():
        for idx,data in enumerate(val_loader):
            X = data[0].to(device)
            y = data[1].to(device)
            y_pred,labels = model(X,y)
            y_pred = y_pred.reshape(-1)  # always 1-D, even for a single score
            labels = labels.float().to(device)
            y_hat = torch.round(y_pred)  # threshold at 0.5

            losses += criterion(y_pred,labels).item()

            positive_mask = labels == 0
            negative_mask = labels == 1

            true_pred += (labels == y_hat).sum().item()
            total += len(y_hat)
            total_true_positive += (y_hat[positive_mask] == 0).sum().item()
            total_true_negative += (y_hat[negative_mask] == 1).sum().item()
            total_positive += int(positive_mask.sum().item())
            total_negative += int(negative_mask.sum().item())

            print('y_pred:',y_hat.detach().cpu().numpy())

    n_batches = len(val_loader)
    mean_loss = losses / n_batches if n_batches else 0.0
    acc = true_pred / total if total else 0.0
    acc_pos = total_true_positive / total_positive if total_positive else 0.0
    acc_neg = total_true_negative / total_negative if total_negative else 0.0
    print(f"Done Validate Epoch {epoch}, Loss: {mean_loss} Acc : {acc * 100 :.2f} Acc_pos :{acc_pos} Acc_neg :{acc_neg}")
    print("______________________________________________________")
    return mean_loss, acc
    

In [None]:
# Build the pair classifier on top of the Swin checkpoint and the pretrained
# frame predictor, then move it to the GPU.
model_sim = Simase_network(
    1024,
    "/kaggle/input/swin-checkpoint/Checkpoint-Swin",
    '/kaggle/input/ae-model/model_plants_grow.pt',
).cuda()
optimizer = torch.optim.Adam(model_sim.parameters(),lr = 0.001)
# Unused class-weighting idea kept for reference:
# weights = torch.tensor([1.0] * 2 + [0.333] * 6).cuda()
criterion_BCE = nn.BCELoss()
epochs = 100
best_acc_pos = 0.0

# Per-epoch training history.
train_losses, train_accs, train_accs_pos, train_accs_neg = [], [], [], []

for epoch in range(epochs):
    epoch_stats = train_epoch_Simase(model_sim,epoch,data_loader,criterion_BCE,optimizer)
    train_loss, train_acc, acc_pos, acc_neg = epoch_stats

    # Checkpoint whenever the accuracy on positive pairs improves.
    improved = acc_pos > best_acc_pos
    if improved:
        best_acc_pos = acc_pos
        torch.save(model_sim.state_dict(),f'model_simaese_{epoch}.pt')

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    train_accs_pos.append(acc_pos)
    train_accs_neg.append(acc_neg)

# Train Plants-Grow

In [None]:
def train_epoch(model,epoch,train_loader,criterion_AE,optimizer):
    """Run one training epoch of the frame predictor.

    Args:
        model: Module called as ``model(X)`` returning the predicted frame.
        epoch: Epoch number, used for logging only.
        train_loader: Iterable with ``len()`` yielding ``(X, y)`` batches.
        criterion_AE: Loss called as ``criterion_AE(prediction, y)``.
        optimizer: Optimizer stepping the model parameters.

    Returns:
        Mean batch loss of the epoch (0.0 for an empty loader).
    """
    model.train()
    # Follow the model's device instead of assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    losses_AE = 0.0
    for idx,data in enumerate(train_loader):
        X = data[0].to(device)
        y = data[1].to(device)
        y_pred =  model(X)
        loss = criterion_AE(y_pred,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses_AE += loss.item()
        print(f"Epoch:{epoch} Batch:{idx} Loss_AE:{loss.item()} ")

    n_batches = len(train_loader)
    mean_loss = losses_AE / n_batches if n_batches else 0.0
    print(f"Done Epoch {epoch}, Loss_AE: {mean_loss}")
    print("______________________________________________________")
    return mean_loss
def generate_pred(model,train_loader,max_images = 16,nrow = 4):
    """Render a grid of model predictions for the first batches of a loader.

    Args:
        model: Module called as ``model(X)`` returning image tensors.
        train_loader: Iterable yielding batches whose first element is ``X``.
        max_images: Maximum number of predictions placed in the grid.
        nrow: Number of images per grid row (passed to ``make_grid``).

    Returns:
        The image grid built by ``torchvision.utils.make_grid`` (on CPU).
    """
    model.eval()
    # Follow the model's device instead of assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    imgs = []
    n_collected = 0
    # Inference only: without no_grad every stored prediction would keep its
    # whole autograd graph alive.
    with torch.no_grad():
        for data in train_loader:
            y = model(data[0].to(device))
            imgs.append(y.cpu())
            n_collected += y.shape[0]
            if n_collected >= max_images:
                # Later batches would be discarded by the slice below anyway.
                break

    imgs = torch.cat(imgs,dim = 0)[:max_images]
    data_img = torchvision.utils.make_grid(imgs,nrow=nrow)
    return data_img
        

In [None]:
epochs = 500

# Frame predictor trained with an L1 reconstruction loss.
model = MyModel(1024,"/kaggle/input/swin-checkpoint/Checkpoint-Swin").to('cuda')
optimizer = torch.optim.Adam(model.parameters(),lr = 0.001)
criterion_AE = nn.L1Loss()

for epoch in range(epochs):
    train_epoch(model,epoch,data_loader,criterion_AE,optimizer)
    if epoch % 5 != 0:
        continue

    # Every fifth epoch: show a grid of current predictions.
    plt.figure()
    prediction_grid = generate_pred(model,data_loader)
    data_img = prediction_grid.detach().cpu()
    plt.imshow(data_img.permute(1,2,0))  # CHW -> HWC for matplotlib
    plt.axis('off')
    plt.show()
    plt.close()
        

In [None]:
def generate_pred_new(model,train_loader):
    """Render a grid of de-normalised predictions for every batch of a loader.

    Each prediction is mapped back with ``y * std + mean`` per channel, using
    the mean/std constants defined below.

    Args:
        model: Module called as ``model(X)`` returning 3-channel image tensors.
        train_loader: Iterable yielding batches whose first element is ``X``.

    Returns:
        The image grid built by ``torchvision.utils.make_grid`` (on CPU).
    """
    model.eval()
    # Follow the model's device instead of assuming a GPU is present.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Per-channel statistics shaped (1, 3, 1, 1) so they broadcast over B, H, W.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)

    imgs = []
    # Inference only: avoids retaining an autograd graph per batch and avoids
    # in-place edits on tensors that are part of a graph.
    with torch.no_grad():
        for data in train_loader:
            y = model(data[0].to(device)).cpu()
            imgs.append(y * std + mean)

    imgs = torch.cat(imgs,dim = 0)
    data_img = torchvision.utils.make_grid(imgs,nrow=4)
    return data_img

# Grid of de-normalised predictions; CHW -> HWC for matplotlib.
prediction_grid = generate_pred_new(model,data_loader)
data_img = prediction_grid.detach().cpu()
plt.imshow(data_img.permute(1,2,0))

In [None]:
# Persist the trained frame-predictor weights.
trained_weights = model.state_dict()
torch.save(trained_weights,'new_model_plants_grow.pt')