In [2]:
import json
import torch
from datasets.custom_dataset import CustomDataset
from datasets.transform import TransformSelector
from models.model_selector import ModelSelector
from utils.train_utils import Trainer
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm
import cv2

import matplotlib
import matplotlib.pyplot as plt
import re

matplotlib.use('Agg')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def image_path_to_num(image_path):
    """Extract the numeric image id from an image path.

    The id is the first run of digits that directly follows an underscore
    in the file's base name, e.g. ``"n01443537/sketch_12.JPEG"`` -> ``12``.

    Args:
        image_path: Path (relative or absolute) to an image file.

    Returns:
        int: The parsed image number.

    Raises:
        ValueError: If the base name contains no ``_<digits>`` token.
    """
    image_name = os.path.basename(image_path)
    match = re.search(r'_(\d+)', image_name)
    if match is None:
        # Fail loudly with the offending path instead of an opaque
        # AttributeError on ``None.group``.
        raise ValueError(f"no '_<number>' token in image file name: {image_path!r}")
    return int(match.group(1))

def create_fold(fold_path):
    """Create ``fold_path`` (and any missing parent directories) if absent.

    Safe to call repeatedly: an existing directory is left untouched.

    Args:
        fold_path: Directory path to create.

    Raises:
        FileExistsError: If ``fold_path`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(fold_path, exist_ok=True)

'/data/ephemeral/home/Dongjin/git/level1-imageclassification-cv-07/Model/0922_valid_analysis'

In [58]:
# Paths: outputs are written under the current working directory; the raw
# training data lives at fixed absolute locations.
current_work_dir = os.getcwd()
save_path_dir = os.path.join(current_work_dir, "result/class")
data_info_file = '/data/ephemeral/home/data/train.csv'
train_data_dir = '/data/ephemeral/home/data/train'
class_to_name_file = os.path.join(current_work_dir, "result/class/map_clsloc.txt")

train_info = pd.read_csv(data_info_file)

# Load the class_name -> category mapping. The file is space-separated with
# no header row; its middle column ("index") is not needed here.
class_to_name = pd.read_csv(class_to_name_file, header=None, sep=" ")
class_to_name.columns = ["class_name", "index", "category"]
class_to_name = class_to_name.drop(columns=["index"])

# Attach each training row's category. validate="many_to_one" raises a
# MergeError if the mapping lists a class_name more than once, which would
# otherwise silently duplicate training rows.
train_info = pd.merge(
    train_info,
    class_to_name,
    on='class_name',
    how='left',
    validate="many_to_one",
)

# NOTE(review): `groups` is not used in the visible cells; kept in case later
# cells rely on it.
groups = train_info.groupby('target')
create_fold(save_path_dir)

In [61]:
# Save one contact sheet per target class: every training image of that class,
# ordered by the number parsed from its file name, written to <target>.png.
GRID_COLS = 6        # images per grid row
MIN_GRID_ROWS = 6    # keeps the 6x6 layout for classes with <= 36 images
INCHES_PER_CELL = 2  # 6 cells * 2 in = 12-inch figure side for the 6x6 layout

for _, group in tqdm(train_info.groupby('target'), desc="class grids"):
    # assign() builds a new frame, avoiding SettingWithCopyWarning from
    # writing a column into the groupby slice.
    group = group.assign(
        image_number=group["image_path"].apply(image_path_to_num)
    ).sort_values(by=['image_number'])

    target = group.target.values[0]
    class_name = group.class_name.values[0]
    category = group.category.values[0]

    # Grow the grid downward when a class has more images than 6x6 holds,
    # instead of failing with an IndexError on the axes array.
    n_rows = max(MIN_GRID_ROWS, (len(group) + GRID_COLS - 1) // GRID_COLS)
    fig, axes = plt.subplots(
        n_rows,
        GRID_COLS,
        figsize=(GRID_COLS * INCHES_PER_CELL, n_rows * INCHES_PER_CELL),
    )
    axes = axes.flatten()

    for ax, row in zip(axes, group.itertuples()):
        image_path = os.path.join(train_data_dir, row.image_path)
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            plt.close(fig)
            raise FileNotFoundError(f"could not read image: {image_path}")
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax.set_title(row.image_number, fontsize=14)

    for ax in axes:
        ax.axis('off')

    title = f'Target: {target}, Category: {category}, Class name: {class_name}'
    fig.suptitle(title, fontsize=20)

    save_path = os.path.join(save_path_dir, f'{target}.png')
    fig.savefig(save_path)
    # Release the figure; otherwise one figure per class accumulates in memory.
    plt.close(fig)