In [None]:
import pandas as pd
import torch
import numpy as np
from PIL import Image
import glob
import os
import matplotlib.pyplot as plt

import datetime

import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
#from torchsummary import summary

from tqdm.notebook import tqdm

# Select the compute device: the first CUDA GPU when one is available, else the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Print the installed PyTorch version alongside the chosen device.
print(torch.__version__)

# Root directory that the "data/..." and "model/..." paths in this notebook are built from.
headder = "/home/kashihara/workspace/2022MAiZM/2022MAiZM_data"

In [None]:
from alibi_detect.utils.visualize import plot_instance_score, plot_feature_outlier_image

In [None]:
class my_dataset(Dataset):
    """Dataset over the ``*.jpg`` files found directly inside one directory.

    Every sample is paired with the same constant label ``0``.

    Args:
        img_path: Directory that is searched (non-recursively) for ``*.jpg`` files.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, img_path, transform=None):
        # glob returns files in an arbitrary, filesystem-dependent order;
        # sorting keeps the index -> file mapping stable between runs.
        self.image_paths = sorted(glob.glob(img_path + '/*.jpg'))
        self.labels = 0
        self.transform = transform

    def __getitem__(self, index):
        """Load image ``index`` and return ``(image, label)``.

        The image is the transformed result when a transform was given,
        otherwise the RGB PIL image itself.
        """
        path = self.image_paths[index]
        # Open inside a context manager so the file handle is released right
        # away, and force three channels so grayscale or CMYK JPEGs do not
        # produce tensors with a different channel count.
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        # Apply the preprocessing transform, if any.
        if self.transform is not None:
            img = self.transform(img)

        return img, self.labels

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

if __name__ == '__main__':
    # Shared preprocessing: centre-crop to 1080 px, shrink to 128x128,
    # then convert the PIL image to a tensor.
    transform = transforms.Compose([
        transforms.CenterCrop(1080),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # Training split: dataset first, then a loader yielding batches of 8.
    dataset = my_dataset("{}/data/train".format(headder), transform)
    dataloader = DataLoader(dataset, batch_size=8)

    # Test split, built the same way.
    test_dataset = my_dataset("{}/data/test".format(headder), transform)
    test_dataloader = DataLoader(test_dataset, batch_size=8)


In [None]:
# Sanity check: print the image-tensor shape of the first test batch only.
for batch in test_dataloader:
    print(batch[0].shape)
    break

In [None]:

def _show_image_row(batch_images):
    """Plot every image of one batch side by side in a single figure.

    Args:
        batch_images: Batch tensor whose items are (channel, row, column) images.
    """
    fig = plt.figure(figsize=(24, 12))
    for num, img in enumerate(batch_images):
        fig.add_subplot(1, len(batch_images), num + 1)
        # imshow expects (row, column, channel). swapaxes(0, 2) would give
        # (column, row, channel) and draw every image transposed.
        plt.imshow(img.cpu().numpy().transpose(1, 2, 0))
    plt.show()


# Preview every training batch, then every test batch.
for a in dataloader:
    _show_image_row(a[0])

for b in test_dataloader:
    _show_image_row(b[0])

In [None]:
class Net(nn.Module):
    """Convolutional variational autoencoder with an 80-dimensional latent space.

    The encoder halves the spatial size three times with strided convolutions
    and once more with max pooling, then flattens to 8 * 8 * 8 features; that
    flattening matches 3-channel 128x128 inputs. The decoder mirrors the path
    with one nearest-neighbour upsampling and three transposed convolutions.
    """

    def __init__(self):
        super().__init__()

        # Encoder convolutions (each halves height and width).
        self.conv1 = nn.Conv2d(3, 32, 5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(32, 16, 5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(16, 8, 5, stride=2, padding=2)

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Latent heads (mean and log-variance) and the projection back out.
        self.fc1_mu = nn.Linear(8 * 8 * 8, 80)
        self.fc1_sig = nn.Linear(8 * 8 * 8, 80)
        self.fc2 = nn.Linear(80, 8 * 8 * 8)

        self.up_sample = nn.UpsamplingNearest2d(scale_factor=2)

        # Decoder transposed convolutions (each doubles height and width).
        self.conv4 = nn.ConvTranspose2d(8, 32, 6, stride=2, padding=2)
        self.conv5 = nn.ConvTranspose2d(32, 16, 6, stride=2, padding=2)
        self.conv6 = nn.ConvTranspose2d(16, 3, 6, stride=2, padding=2)

    def encode(self, x):
        """Map an image batch to the latent mean and log-variance."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        flat = self.max_pool(features).reshape(-1, 8 * 8 * 8)

        latent_mu = self.fc1_mu(flat)
        latent_logvar = self.fc1_sig(flat)
        return latent_mu, latent_logvar

    def decode(self, z):
        """Map latent vectors to images with values in [0, 1] (sigmoid output)."""
        hidden = F.relu(self.fc2(z)).reshape(-1, 8, 8, 8)
        hidden = self.up_sample(hidden)
        for deconv in (self.conv4, self.conv5):
            hidden = F.relu(deconv(hidden))
        return torch.sigmoid(self.conv6(hidden))

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return (noise * std).add_(mu)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for an image batch."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

In [None]:

# Build the network on the selected device and attach an Adam optimizer to its parameters.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [None]:
# Show the module's repr (its layer structure) as the cell output.
model

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE loss: reconstruction cross-entropy plus KL term.

    Args:
        recon_x: Reconstructed batch with values in [0, 1].
        x: Target batch with the same shape as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor: binary cross-entropy summed over all elements plus
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # pow(2) squares mu; exp() turns the log-variance back into a variance.
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner)

    return reconstruction_term + kl_term

In [None]:
num_epochs = 10000  # number of training epochs
print_per = 100     # record the running loss every `print_per` minibatches
model.train()

# Timestamps use a fixed UTC+9 offset labelled 'JST'; built once and reused.
t_delta = datetime.timedelta(hours=9)
JST = datetime.timezone(t_delta, 'JST')
now = datetime.datetime.now(JST)
d_dir = now.strftime('%Y%m%d%H%M%S')

# makedirs (unlike mkdir) also creates a missing parent "model" directory and
# does not fail when the run directory already exists.
os.makedirs("{}/model/{}".format(headder, d_dir), exist_ok=True)
loss_data = []

for epoch in tqdm(range(num_epochs)):
    train_loss = 0   # loss summed over the whole epoch
    print_loss = 0   # loss summed since the last `print_per` checkpoint
    loss_record = []

    # Per-epoch timestamp, used in the checkpoint file name.
    now = datetime.datetime.now(JST)
    d = now.strftime('%Y%m%d%H%M%S')

    for i, (images) in enumerate(dataloader):

        optimizer.zero_grad()
        # Each batch is (images, labels); only the images are used.
        images = images[0].to(device)

        # One forward pass per step. The original ran the model a second time
        # into an unused variable, doubling the forward cost.
        recon_batch, mu, logvar = model(images)

        loss = loss_function(recon_batch, images, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        print_loss += loss.item()
        optimizer.step()

        if (i % print_per == 0):
            loss_record.append(print_loss)
            print_loss = 0

    if epoch % 10 == 0:
        print("Epoch {} : Loss = ({:.4f}) ".format(epoch+1, train_loss))

        # Last training batch of this epoch.
        fig = plt.figure(figsize=(24, 12))
        for num, img in enumerate(images):
            fig.add_subplot(1, len(images), num + 1)
            # imshow expects (row, column, channel); swapaxes(0, 2) would
            # draw the image transposed.
            plt.imshow(img.detach().cpu().numpy().transpose(1, 2, 0))
        plt.show()

        # Reconstruction of that batch; preview only, so no autograd graph.
        with torch.no_grad():
            a = model(images.to(device))
        fig = plt.figure(figsize=(24, 12))
        for num, img in enumerate(a[0]):
            fig.add_subplot(1, len(a[0]), num + 1)
            plt.imshow(img.detach().cpu().numpy().transpose(1, 2, 0))
        plt.show()

    if epoch % 100 == 0:
        torch.save(model.state_dict(), "{}/model/{}/{}_{}_.pth".format(headder, d_dir, d, epoch))

    # Loss curve over all epochs so far, redrawn after every epoch.
    loss_data.append(train_loss)
    plt.plot(loss_data)
    plt.show()

In [None]:
# Checkpoint to restore: <root>/model/<run directory>/<checkpoint file>.
model_path = '{}/model/{}/{}'.format(headder,"20220719174322","20220720204159_8200_.pth")
model = Net()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU still loads when `device` fell back to the CPU.
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
def _chw_to_hwc(image):
    """Return one (channel, row, column) tensor as a (row, column, channel) array.

    plt.imshow expects the channel axis last; swapaxes(0, 2) would also swap
    rows with columns and draw the image transposed.
    """
    return image.detach().cpu().numpy().transpose(1, 2, 0)


def _plot_image_row(arrays):
    """Draw the given image arrays side by side in one new figure."""
    fig = plt.figure(figsize=(24, 12))
    for num, array in enumerate(arrays):
        fig.add_subplot(1, len(arrays), num + 1)
        plt.imshow(array)


# Inference only: no autograd graph is needed for any forward pass below.
with torch.no_grad():
    # First training batch: the input images.
    for b in dataloader:
        _plot_image_row([_chw_to_hwc(img) for img in b[0]])
        break

    # First training batch: reconstructions.
    for i, (inputs, _) in enumerate(dataloader, 0):
        a = model(inputs.to(device))
        _plot_image_row([_chw_to_hwc(img) for img in a[0]])
        break

    # First training batch: reconstruction minus input, clipped to [0, 1].
    # NOTE(review): the test rows below use input minus reconstruction (the
    # opposite sign); confirm which direction is intended.
    for i, (inputs, _) in enumerate(dataloader, 0):
        a = model(inputs.to(device))
        _plot_image_row([
            np.clip(_chw_to_hwc(img) - _chw_to_hwc(inputs[num]), 0, 1)
            for num, img in enumerate(a[0])
        ])
        break

    # Every test batch: the input images.
    for b in test_dataloader:
        _plot_image_row([_chw_to_hwc(img) for img in b[0]])

    # Every test batch: reconstructions.
    for i, (inputs, _) in enumerate(test_dataloader, 0):
        a = model(inputs.to(device))
        _plot_image_row([_chw_to_hwc(img) for img in a[0]])

    # Every test batch: input minus reconstruction, clipped to [0, 1].
    for i, (inputs, _) in enumerate(test_dataloader, 0):
        a = model(inputs.to(device))
        _plot_image_row([
            np.clip(_chw_to_hwc(inputs[num]) - _chw_to_hwc(img), 0, 1)
            for num, img in enumerate(a[0])
        ])

In [None]:
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
# Per-channel heatmaps for the first training batch only.
with torch.no_grad():
    for inputs, _ in dataloader:
        recon_batch = model(inputs.to(device))[0]

        for num, img in enumerate(recon_batch):
            recon = img.detach().cpu().numpy()
            # Input minus reconstruction for this image, channel axis first.
            imim = inputs[num].cpu().numpy() - recon

            for channel in range(3):
                # NOTE(review): `imim` already has the reconstruction
                # subtracted, so this plots input - 2 * reconstruction. Kept
                # as in the original; confirm whether the plain residual
                # imim[channel] was intended.
                channel_map = imim[channel] - recon[channel]
                # A dedicated name (instead of reusing `a`) keeps the model
                # output intact, and each heatmap gets its own figure; the
                # original also created an extra figure that stayed empty.
                fig, ax = plt.subplots()
                sns.heatmap(channel_map, cmap='bwr', ax=ax)
                plt.show()
            print("#################")

        break

In [None]:
# Per-channel heatmaps for every test batch.
with torch.no_grad():
    for inputs, _ in test_dataloader:
        recon_batch = model(inputs.to(device))[0]

        for num, img in enumerate(recon_batch):
            recon = img.detach().cpu().numpy()
            # Input minus reconstruction for this image, channel axis first.
            imim = inputs[num].cpu().numpy() - recon

            for channel in range(3):
                # NOTE(review): `imim` already has the reconstruction
                # subtracted, so this plots input - 2 * reconstruction. Kept
                # as in the original; confirm whether the plain residual
                # imim[channel] was intended.
                channel_map = imim[channel] - recon[channel]
                # A dedicated name (instead of reusing `a`) keeps the model
                # output intact, and each heatmap gets its own figure; the
                # original also created an extra figure that stayed empty.
                fig, ax = plt.subplots()
                sns.heatmap(channel_map, cmap='bwr', ax=ax)
                plt.show()
            print("#################")


In [None]:
# Shape of the last residual array left in `imim` by an earlier heatmap cell.
imim.shape

In [None]:
from alibi_detect.od import OutlierVAE