# Basic Settings

In [1]:
import random
import pandas as pd
import numpy as np
import os
import glob
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from datetime import datetime

import warnings
warnings.filterwarnings(action='ignore') 

In [2]:
# Run-wide configuration knobs.
CFG = dict(
    RUN_TYPE='INFERENCE',              # run-mode label (not read by the code in this notebook)
    CLF_PATHS='./cnn_classifier.pth',  # state_dict file loaded into the case classifier
    SEED=42,                           # seed passed to seed_everything
    H_SHIFT_MAX=2,                     # query is shifted by every h in [-H_SHIFT_MAX, H_SHIFT_MAX] rows
    W_SHIFT_MAX=2,                     # ... and every w in [-W_SHIFT_MAX, W_SHIFT_MAX] columns
    KNN_NUM=1,                         # number of nearest simulated samples whose depth maps are blended
)

In [3]:
def seed_everything(seed):
    """Seed every RNG the notebook uses and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices).
    It also exports PYTHONHASHSEED; that only affects processes started afterwards,
    because the current interpreter's hash seed is fixed at startup.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic algorithms per
    # input shape, which defeats `deterministic = True` above. Keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=CFG['SEED'])  # fix every random seed

In [4]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare Dataset

In [5]:
def parse_case_id(path):
  """Return the integer case id encoded in a simulation file name.

  File names look like ``SEM_1_00000.png`` / ``Depth_1_00000.png``; the case id is the second
  ``_``-separated field of the base name. This replaces a fixed character index into the full
  path, which only worked for one exact directory prefix and single-digit case ids.
  """
  return int(os.path.basename(path).split('_')[1])


simulation_sem_paths = sorted(glob.glob('./processed_data/simulation/SEM/*.png'))
simulation_depth_paths = sorted(glob.glob('./processed_data/simulation/Depth/*.png'))

# SEM images and depth maps are paired purely by sorted position. Check that every pair shares
# the same "<case>_<index>" suffix, so a missing or extra file cannot silently shift all rows.
assert [os.path.basename(p).split('_', 1)[1] for p in simulation_sem_paths] == \
       [os.path.basename(p).split('_', 1)[1] for p in simulation_depth_paths], \
       'SEM and Depth files do not pair up by name'

df_simul = pd.DataFrame({'SEM': simulation_sem_paths, 'depth': simulation_depth_paths})
df_simul['case'] = df_simul['SEM'].apply(parse_case_id)

In [6]:
# Sanity check of the simulation table: one row per (SEM path, depth path) pair plus its case id.
df_simul

Unnamed: 0,SEM,depth,case
0,./processed_data/simulation/SEM\SEM_1_00000.png,./processed_data/simulation/Depth\Depth_1_0000...,1
1,./processed_data/simulation/SEM\SEM_1_00001.png,./processed_data/simulation/Depth\Depth_1_0000...,1
2,./processed_data/simulation/SEM\SEM_1_00002.png,./processed_data/simulation/Depth\Depth_1_0000...,1
3,./processed_data/simulation/SEM\SEM_1_00003.png,./processed_data/simulation/Depth\Depth_1_0000...,1
4,./processed_data/simulation/SEM\SEM_1_00004.png,./processed_data/simulation/Depth\Depth_1_0000...,1
...,...,...,...
86647,./processed_data/simulation/SEM\SEM_4_86647.png,./processed_data/simulation/Depth\Depth_4_8664...,4
86648,./processed_data/simulation/SEM\SEM_4_86648.png,./processed_data/simulation/Depth\Depth_4_8664...,4
86649,./processed_data/simulation/SEM\SEM_4_86649.png,./processed_data/simulation/Depth\Depth_4_8664...,4
86650,./processed_data/simulation/SEM\SEM_4_86650.png,./processed_data/simulation/Depth\Depth_4_8665...,4


In [7]:
class Clf_Dataset(Dataset):
  """Grayscale SEM images for the case classifier.

  Each item is a float tensor of shape (1, H, W): the 8-bit grayscale pixels divided by 255.
  """
  def __init__(self, sem_path_list):
    self.sem_path_list = sem_path_list

  def __len__(self):
    return len(self.sem_path_list)

  def __getitem__(self, idx):
    path = self.sem_path_list[idx]
    raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
      # cv2.imread returns None for a missing/unreadable file; fail with a clear message
      # instead of an obscure error from torch.Tensor(None).
      raise FileNotFoundError(f'could not read image: {path}')
    return torch.Tensor(raw)[None, :] / 255

# Case Classifier Model

In [8]:
!pip install timm



In [9]:
import timm # timm 설치 필요

class Case_Classifier(nn.Module):
    """timm `tf_efficientnet_b0_ns` backbone mapping a 1-channel image to 4 case probabilities.

    The attribute names `model` and `softmax` must not change: they define the state_dict
    keys expected by `load_state_dict`.
    """
    def __init__(self, pretrained=True):
        """
        Args:
            pretrained: forwarded to `timm.create_model`. Defaults to True (previous behaviour).
                Pass False when a trained state_dict is loaded right afterwards, to skip
                fetching pretrained weights that would be overwritten anyway.
        """
        super(Case_Classifier, self).__init__()
        self.model = timm.create_model('tf_efficientnet_b0_ns', pretrained=pretrained, num_classes=4, in_chans=1)
        # Explicit dim=1 is the class axis of the (batch, 4) logits. nn.Softmax() without dim
        # is deprecated and only chose dim=1 for 2-D input via an implicit-dim heuristic.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-class probabilities (rows sum to 1) for the input batch `x`."""
        return self.softmax(self.model(x))

# Case Classification

In [10]:
# Test SEM images, sorted so that predictions can be matched back to this list by position.
test_sem_path_list = sorted(glob.glob('./test/SEM/*.png'))
if not test_sem_path_list:
  # An empty glob would otherwise flow silently through every later cell and end in an
  # empty submission archive.
  raise FileNotFoundError('no test images found under ./test/SEM/')

clf_set = Clf_Dataset(test_sem_path_list)
# shuffle=False is required: the predicted cases are zipped back onto test_sem_path_list in order.
clf_loader = DataLoader(clf_set, batch_size=64, shuffle=False)

In [11]:
classifier = Case_Classifier().to(device).eval()
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine (and vice versa).
classifier.load_state_dict(torch.load(CFG['CLF_PATHS'], map_location=device))

<All keys matched successfully>

In [12]:
# Predict the case of every test image. argmax is 0-based, so +1 maps it to case ids 1-4.
case_list = []
img = preds = None  # pre-bind so the `del` below also works when the loader is empty
with torch.no_grad():
  for img in tqdm(clf_loader):
    preds = classifier(img.to(device)).argmax(dim=1) + 1
    # Plain Python ints (instead of 0-d NumPy arrays) give df_test['case'] an integer dtype.
    case_list.extend(preds.cpu().tolist())

df_test = pd.DataFrame({'path': test_sem_path_list, 'case': case_list})
# Release the last batch/predictions and the loader before the memory-heavy similarity search.
del img, preds, clf_loader

100%|████████████████████████████████████████████████████████████████████████████████| 407/407 [00:12<00:00, 31.95it/s]


In [13]:
# Number of test images the classifier assigned to each case id 1-4.
for case_id in range(1, 5):
  n_in_case = int((df_test['case'] == case_id).sum())
  print(n_in_case)

6489
6499
6516
6484


# Similarity Search and Depth Prediction

In [18]:
def _load_standardized_sem(path):
  """Read a grayscale SEM image and standardize it to zero mean and unit (population) std.

  Returns a float tensor of shape (1, H, W). After standardization, the cosine similarity
  used later equals the Pearson correlation of pixel intensities.

  Raises:
    FileNotFoundError: if cv2 cannot read `path`.
  """
  raw = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
  if raw is None:
    raise FileNotFoundError(f'could not read image: {path}')
  img = torch.Tensor(raw)[None, :]
  avg = torch.mean(img)
  # Clamp so a perfectly uniform image (std == 0) becomes all zeros instead of 0/0 = NaN,
  # which would poison the similarity matrix and topk. Any non-uniform 8-bit image has a std
  # far above this floor, so its values are unchanged.
  std = torch.std(img, unbiased=False).clamp_min(1e-6)
  return (img - avg) / std


class Inference_key_dataset(Dataset):
  """Simulated SEM images (the search keys) from df['SEM'], standardized per image."""
  def __init__(self, df):
    self.df = df

  def __len__(self):
    return len(self.df)

  def __getitem__(self, idx):
    return _load_standardized_sem(self.df['SEM'].iloc[idx])


class Inference_query_dataset(Dataset):
  """Test SEM images (the queries); the image path is taken from the first column of df.

  Each item is (standardized image of shape (1, H, W), img_name).
  """
  def __init__(self, df):
    self.df = df

  def __len__(self):
    return len(self.df)

  def __getitem__(self, idx):
    sem_path = self.df.iloc[idx, 0]
    # Only '/' is treated as a separator. A Windows glob path such as './test/SEM\\X.png'
    # therefore yields 'SEM\\X.png'. The format is kept as-is because downstream code consumes it.
    img_name = sem_path.split('/')[-1]
    return _load_standardized_sem(sem_path), img_name

In [19]:
def shift_img(img, h_shift, w_shift):
  """Circularly shift a batch of images and return the result on `device`.

  Args:
    img: tensor indexed as (batch, channel, height, width).
    h_shift: output row i takes input row (i + h_shift) mod H.
    w_shift: output column j takes input column (j - w_shift) mod W.

  The two axes use opposite sign conventions (as in the original slice-based version). The
  similarity search sweeps symmetric shift ranges and keeps the maximum, so this is harmless there.

  Returns:
    Tensor with the same shape and dtype as `img`, moved to the notebook-level `device`.
  """
  # torch.roll performs exactly this wrap-around copy. It works for any H/W (no hard-coded
  # 72x48) and stays on img's device instead of bouncing through freshly allocated CPU buffers.
  return torch.roll(img, shifts=(-h_shift, w_shift), dims=(2, 3)).to(device)


In [20]:
def get_batch_similarity(query, key):
  """Pairwise cosine similarity between the rows of `query` (N, D) and `key` (M, D).

  Returns an (N, M) tensor whose [i, j] entry is
  <query[i], key[j]> / (|query[i]| * |key[j]|).
  """
  dots = query @ key.T
  q_len = torch.norm(query, dim=1, keepdim=True)  # (N, 1)
  k_len = torch.norm(key, dim=1).unsqueeze(0)     # (1, M)
  return dots / q_len / k_len

def get_value_and_weight_sum(k_args, k_weights, df, default_shape=(72, 48)):
  """Blend the depth maps of the selected rows of `df`, weighted by `k_weights`.

  Args:
    k_args: 1-D tensor of positional row indices into `df`.
    k_weights: 1-D tensor of weights, one per index in `k_args`.
    df: DataFrame with a 'depth' column of grayscale image paths.
    default_shape: shape of the all-zero result returned when `k_args` is empty
      (72x48, matching the previous fixed-size accumulator).

  Returns:
    Float tensor sum_i weight_i * depth_image_i, with the shape of the depth images.

  Raises:
    FileNotFoundError: if a depth image cannot be read.
  """
  pred = None
  for arg, weight in zip(k_args, k_weights):
    value_path = df['depth'].iloc[arg.item()]
    raw = cv2.imread(value_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
      raise FileNotFoundError(f'could not read depth image: {value_path}')
    value = torch.Tensor(raw) * weight.item()
    # Accumulate into the images' own shape instead of a hard-coded 72x48 buffer.
    pred = value if pred is None else pred + value
  return torch.zeros(default_shape) if pred is None else pred

# Normalizes each row of top-k similarities into blending weights that sum to 1.
softmax = torch.nn.Softmax(dim=1)

In [None]:
from tqdm.contrib.itertools import product as tqdm_product  # itertools.product with a progress bar


def _max_shift_similarity(query, key_batches, h_shift_max, w_shift_max):
  """Best cosine similarity of each query against each key over all circular shifts of the query.

  Args:
    query: standardized query batch on `device`, shape (Q, 1, H, W).
    key_batches: list of flattened key batches, each (B_i, H*W), kept on the CPU.
    h_shift_max, w_shift_max: every shift in [-max, max] is tried along each axis.

  Returns:
    (Q, sum(B_i)) tensor; columns follow the order of the concatenated key batches.
  """
  if not key_batches:
    raise ValueError('no simulated keys to compare against')
  best = None
  shifts = tqdm_product(range(-h_shift_max, h_shift_max + 1), range(-w_shift_max, w_shift_max + 1))
  for h_shift, w_shift in shifts:
    shifted_query = shift_img(query, h_shift, w_shift).flatten(1)
    shift_similarity = torch.cat(
        [get_batch_similarity(shifted_query, keys.to(device)) for keys in key_batches], dim=1)
    # An element-wise running max keeps a 2-D (Q, K) matrix. It avoids a concatenate-then-reduce
    # copy and a later torch.squeeze, which would also collapse Q == 1 or K == 1 and break topk.
    best = shift_similarity if best is None else torch.maximum(best, shift_similarity)
  return best


result_list = []
result_name_list = []
with torch.no_grad():
  for case in range(1, 5):
    print(f'calculating similarities for case_{case}')
    temp_simul = df_simul[df_simul['case'] == case]
    temp_test = df_test[df_test['case'] == case]
    if temp_test.empty:
      continue  # no test image was classified into this case
    simul_loader = DataLoader(Inference_key_dataset(temp_simul), batch_size=7000, shuffle=False)
    test_loader = DataLoader(Inference_query_dataset(temp_test), batch_size=14000, shuffle=False)

    # Decode and standardize the simulated keys once per case and keep them on the CPU,
    # rather than re-reading every key PNG from disk for each of the shifts.
    key_batches = [keys.flatten(1) for keys in simul_loader]

    for test_sem, img_name in test_loader:
      similarity_matrix = _max_shift_similarity(
          test_sem.to(device), key_batches, CFG['H_SHIFT_MAX'], CFG['W_SHIFT_MAX'])
      # Row-wise top-k; the indices are positional rows of temp_simul (loader is unshuffled).
      best_k_sim, best_k_args = torch.topk(similarity_matrix, k=CFG['KNN_NUM'], dim=1)
      best_k_weights = softmax(best_k_sim)

      for args, probs, name in zip(best_k_args, best_k_weights, img_name):
        pred = get_value_and_weight_sum(args, probs, temp_simul)
        result_name_list.append(name)
        result_list.append(pred[:, :, None])

In [None]:
import zipfile

test_id = 'Submission_F'
# Build every path explicitly instead of os.chdir-ing around. If a write fails half-way, the
# kernel's working directory is left untouched and the cell can simply be re-run.
submission_dir = os.path.join('.', test_id)
sem_dir = os.path.join(submission_dir, 'SEM')
zip_path = os.path.join(submission_dir, f'{test_id}.zip')
os.makedirs(sem_dir, exist_ok=True)

sub_imgs = []
for name, pred_img in zip(result_name_list, result_list):
    # Image names may carry a leading 'SEM\' (Windows glob) or be bare file names. Normalize
    # both separators so only the file name is kept on every OS.
    file_name = os.path.basename(name.replace('\\', '/'))
    # Blended depth maps are float; round and clip to 8-bit explicitly rather than relying on
    # cv2.imwrite's implicit conversion.
    depth_u8 = np.clip(np.rint(pred_img.numpy()), 0, 255).astype(np.uint8)
    if not cv2.imwrite(os.path.join(sem_dir, file_name), depth_u8):
        raise IOError(f'failed to write {file_name}')
    sub_imgs.append(file_name)

with zipfile.ZipFile(zip_path, 'w') as submission:
    for file_name in sub_imgs:
        # arcname keeps the archive flat: file names sit at the archive root.
        submission.write(os.path.join(sem_dir, file_name), arcname=file_name)

zip_path

# Submit ./Submission_F/Submission_F.zip, or zip the PNGs in ./Submission_F/SEM yourself.