In [9]:
import json
import torch
from datasets.custom_dataset import CustomDataset
from datasets.transform import TransformSelector
from models.model_selector import ModelSelector
from utils.train_utils import Trainer
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
import os
from tqdm import tqdm

import matplotlib
import matplotlib.pyplot as plt

In [10]:
# 추론(inference) 함수
def inference(model, device, test_loader):
    model.eval()
    predictions = []
    with torch.no_grad():
        for images in tqdm(test_loader, desc="Inference"):
            images = images.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)
            predictions.extend(preds.cpu().numpy())
    return predictions

# 'best_model' 파일을 찾는 함수 (가장 최근에 저장된 파일 선택)
def get_best_model_path(directory):
    files = [f for f in os.listdir(directory) if f.startswith('best_model') and f.endswith('.pt')]
    if not files:
        raise FileNotFoundError(f"No best model files found in directory: {directory}")
    
    # 파일의 수정 시간을 기준으로 가장 최근에 저장된 파일 선택
    best_file = max(files, key=lambda f: os.path.getmtime(os.path.join(directory, f)))
    return os.path.join(directory, best_file)

def validate(model, device, val_loader):
        model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        
        loss_fn = nn.CrossEntropyLoss(label_smoothing=0.08)
        total_batches = len(val_loader)
        
        progress_bar = tqdm(val_loader, desc="Validating", leave=False)
        
        with torch.no_grad():
            for batch_idx, (images, targets, indices) in enumerate(progress_bar):
                images, targets = images.to(device), targets.to(device)
                
                # 모델 예측
                outputs = model(images)
                outputs_softmax = torch.nn.functional.softmax(outputs)
                
                # 손실 계산
                loss = loss_fn(outputs, targets)
                total_loss += loss.item()
                
                # 정확도 계산
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == targets).sum().item()
                total += targets.size(0)
                
                progress_bar.set_postfix({'val_batch_loss': loss.item()})

                # 이미지 그리기
                top_fives, top_fives_indices = torch.topk(outputs_softmax, 5)

                images = images.cpu().numpy()
                targets = targets.cpu().numpy()
                indices = indices.cpu().numpy()
                outputs_softmax = outputs_softmax.cpu().numpy()
                top_fives = top_fives.cpu().numpy()
                top_fives_indices = top_fives_indices.cpu().numpy()

                for i, index in enumerate(indices):
                     

        
        # 검증의 평균 손실과 정확도 계산
        avg_loss = total_loss / total_batches
        accuracy = correct / total
        
        print(f"Validation Epoch Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        
        return avg_loss, accuracy

IndentationError: expected an indented block after 'for' statement on line 63 (3701377864.py, line 68)

In [3]:
config_path = r"/data/ephemeral/home/Dongjin/git/level1-imageclassification-cv-07/LV1/config_dj1.json"
with open(config_path, 'r') as f:
  config = json.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_info = pd.read_csv(config['data_info_file'])

py_dir_path = r"/data/ephemeral/home/Dongjin/git/level1-imageclassification-cv-07/Model/0921_valid_analysis"
rel_train_index_path = r"datasets/train_index.csv"
rel_val_index_path = r"datasets/val_index.csv"

train_index_path = os.path.join(py_dir_path, rel_train_index_path)
val_index_path = os.path.join(py_dir_path, rel_val_index_path)

# train_index.csv와 val_index.csv를 이용하여 train_df와 val_df를 로드       
train_index = pd.read_csv(train_index_path, header = None).squeeze()
val_index = pd.read_csv(val_index_path, header = None).squeeze()

train_df = train_info.loc[train_index]
val_df = train_info.loc[val_index]

transform_selector = TransformSelector(transform_type="albumentations")
val_transform = transform_selector.get_transform(is_train=False)

val_dataset = CustomDataset(root_dir=config['train_data_dir'], info_df=val_df, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

# 모델 설정
model_selector = ModelSelector(model_type="timm", num_classes=config['num_classes'], model_name=config['model_name'], pretrained=False)
model = model_selector.get_model()

# 베스트 모델 경로 설정
model_path = get_best_model_path(config['result_path'])
print(f"Loading best model from {model_path}")

# 저장된 모델 로드
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

Loading best model from /data/ephemeral/home/Dongjin/git/level1-imageclassification-cv-07/LV1/results/coatnet_2_rw_224_1/best_model_1.2901.pt


TimmModel(
  (model): MaxxVit(
    (stem): Stem(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm1): BatchNormAct2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): SiLU(inplace=True)
      )
      (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (stages): Sequential(
      (0): MaxxVitStage(
        (blocks): Sequential(
          (0): MbConvBlock(
            (shortcut): Downsample2d(
              (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
              (expand): Identity()
            )
            (pre_norm): BatchNormAct2d(
              128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
              (drop): Identity()
              (act): SiLU(inplace=True)
            )
            (down): Identity()
            (conv1_1x1): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bia

In [12]:
# 추론 실행
validate(model, device, val_loader)

KeyboardInterrupt: 

In [14]:
model.eval()
total_loss = 0.0
correct = 0
total = 0

loss_fn = nn.CrossEntropyLoss(label_smoothing=0.08)
total_batches = len(val_loader)

progress_bar = tqdm(val_loader, desc="Validating", leave=False)

with torch.no_grad():
    for batch_idx, (images, targets, indices) in enumerate(progress_bar):
        images, targets = images.to(device), targets.to(device)
        
        # 모델 예측
        outputs = model(images)
        outputs_softmax = torch.nn.functional.softmax(outputs)
        
        # 손실 계산
        loss = loss_fn(outputs, targets)
        total_loss += loss.item()
        
        # 정확도 계산
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)
        
        progress_bar.set_postfix({'val_batch_loss': loss.item()})

        # 이미지 그리기
        top_fives, top_fives_indices = torch.topk(outputs_softmax, 5)

        images = images.cpu().numpy()
        targets = targets.cpu().numpy()
        indices = indices.cpu().numpy()
        outputs_softmax = outputs_softmax.cpu().numpy()
        top_fives = top_fives.cpu().numpy()
        top_fives_indices = top_fives_indices.cpu().numpy()
        predicted = predicted.cpu().numpy()

        for i, index in enumerate(indices):
            print(i)

  outputs_softmax = torch.nn.functional.softmax(outputs)
Validating:   1%|          | 1/94 [00:00<00:26,  3.46it/s, val_batch_loss=1.29]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:   2%|▏         | 2/94 [00:00<00:26,  3.48it/s, val_batch_loss=1.85]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:   3%|▎         | 3/94 [00:00<00:25,  3.52it/s, val_batch_loss=1.67]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:   4%|▍         | 4/94 [00:01<00:25,  3.56it/s, val_batch_loss=1.12]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:   5%|▌         | 5/94 [00:01<00:24,  3.56it/s, val_batch_loss=1.49]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:   6%|▋         | 6/94 [00:01<00:24,  3.54it/s, val_batch_loss=1.15]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:   7%|▋         | 7/94 [00:01<00:24,  3.62it/s, val_batch_loss=1.03]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:   9%|▊         | 8/94 [00:02<00:23,  3.60it/s, val_batch_loss=1.31]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  10%|▉         | 9/94 [00:02<00:23,  3.58it/s, val_batch_loss=1.28]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  11%|█         | 10/94 [00:02<00:23,  3.62it/s, val_batch_loss=1.02]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  12%|█▏        | 11/94 [00:03<00:23,  3.56it/s, val_batch_loss=1.41]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  13%|█▎        | 12/94 [00:03<00:23,  3.56it/s, val_batch_loss=1.51]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  14%|█▍        | 13/94 [00:03<00:24,  3.33it/s, val_batch_loss=1.62]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  15%|█▍        | 14/94 [00:04<00:23,  3.35it/s, val_batch_loss=1.38]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  16%|█▌        | 15/94 [00:04<00:23,  3.33it/s, val_batch_loss=1.05]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  17%|█▋        | 16/94 [00:04<00:23,  3.31it/s, val_batch_loss=1.34]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  18%|█▊        | 17/94 [00:04<00:23,  3.34it/s, val_batch_loss=1.2] 

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  19%|█▉        | 18/94 [00:05<00:24,  3.13it/s, val_batch_loss=1.14]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  20%|██        | 19/94 [00:05<00:23,  3.18it/s, val_batch_loss=1.15]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  21%|██▏       | 20/94 [00:05<00:22,  3.32it/s, val_batch_loss=1.38]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  22%|██▏       | 21/94 [00:06<00:21,  3.42it/s, val_batch_loss=1.13]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  23%|██▎       | 22/94 [00:06<00:20,  3.48it/s, val_batch_loss=1.16]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  24%|██▍       | 23/94 [00:06<00:20,  3.54it/s, val_batch_loss=1.21]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  26%|██▌       | 24/94 [00:06<00:20,  3.41it/s, val_batch_loss=1.65]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  27%|██▋       | 25/94 [00:07<00:20,  3.39it/s, val_batch_loss=1.16]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  28%|██▊       | 26/94 [00:07<00:19,  3.40it/s, val_batch_loss=1.2] 

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  29%|██▊       | 27/94 [00:07<00:19,  3.38it/s, val_batch_loss=1.28]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  30%|██▉       | 28/94 [00:08<00:19,  3.41it/s, val_batch_loss=1.41]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  31%|███       | 29/94 [00:08<00:18,  3.51it/s, val_batch_loss=1.07]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  32%|███▏      | 30/94 [00:08<00:18,  3.53it/s, val_batch_loss=1.44]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  33%|███▎      | 31/94 [00:09<00:18,  3.46it/s, val_batch_loss=1.15]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  34%|███▍      | 32/94 [00:09<00:18,  3.39it/s, val_batch_loss=1.23]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  35%|███▌      | 33/94 [00:09<00:17,  3.42it/s, val_batch_loss=1.3] 

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  36%|███▌      | 34/94 [00:09<00:17,  3.45it/s, val_batch_loss=1.43]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  37%|███▋      | 35/94 [00:10<00:17,  3.46it/s, val_batch_loss=1.32]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  38%|███▊      | 36/94 [00:10<00:16,  3.41it/s, val_batch_loss=1.11]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  39%|███▉      | 37/94 [00:10<00:16,  3.41it/s, val_batch_loss=1.01]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  40%|████      | 38/94 [00:11<00:16,  3.48it/s, val_batch_loss=1.17]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


Validating:  41%|████▏     | 39/94 [00:11<00:15,  3.53it/s, val_batch_loss=1.4] 

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


                                                                                

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31




KeyboardInterrupt: 