In [1]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import pandas as pd
import random
import os
from tqdm import tqdm
import torchvision.transforms as T
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data import ConcatDataset
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, OneCycleLR, CosineAnnealingWarmRestarts, StepLR
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedKFold
import cv2
import wandb

In [3]:
wandb.init(
    project="ensemble CLIP-ViT & EfficientNetB3 & Swin & ResNet50_ROP",

    config={
    "scheduler": "ReduceLROnPlateau",
    "optimizer" : "AdamW",
    "optimizer-lr" : 5e-4,
    "optimizer-weight_decay" : 1e-4,
    "scheduler-step_size": 3,
    "scheduler-gamma":0.1,
    "architecture": "CLIP-ViT & EfficientNetB3 & Swin & ResNet50",
    "dataset": "SketchNet",
    "epochs": 5,
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33msuperl3[0m ([33msuperl3-naver[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_base_dir = '/data/ephemeral/home/data/train'
test_base_dir = '/data/ephemeral/home/data/test'


traindata_info_file = "/data/ephemeral/home/data/train.csv"

testdata_info_file = "/data/ephemeral/home/data/test.csv"

train_data = pd.read_csv(traindata_info_file)

test_data = pd.read_csv(testdata_info_file)

x = train_data['image_path']
y = train_data['target']

train_data.head(3)

Unnamed: 0,class_name,image_path,target
0,n01872401,n01872401/sketch_50.JPEG,59
1,n02417914,n02417914/sketch_11.JPEG,202
2,n02106166,n02106166/sketch_3.JPEG,138


In [5]:
class Train():
    def __init__(self, model, device, train_loader, val_loader, epochs, optimizer, criterion, scheduler, 
                 early_stop = False, patience_limit = None, best_val_loss = float('inf'), best_model = None):
        self.model = model
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = epochs
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.early_stop = early_stop
        self.patience_limit = patience_limit
        self.best_val_loss = best_val_loss
        self.best_model = best_model

    def train(self):
        patience_check = 0
        self.model.to(self.device)
        for epoch in range(self.epochs):
            running_loss = 0.0

            # 모델 학습
            torch.cuda.empty_cache()
            self.model.train()
            for idx, (images, labels) in enumerate(tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.epochs}")):
                images, labels = images.to(self.device), labels.to(self.device)
                
                outputs = self.model(images)
                
                loss = self.criterion(outputs, labels)
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                running_loss += loss.item()
            
            train_loss = running_loss / len(self.train_loader)

            # 모델 평가
            torch.cuda.empty_cache()
            self.model.eval()
            correct = 0
            total = 0
            running_val_loss = 0.0
            with torch.no_grad():
                for inputs, labels in self.val_loader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)
                    
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, labels)
                    running_val_loss += loss.item()
                    _, pred = torch.max(outputs, 1)
                    
                    total += labels.size(0)
                    correct += (pred == labels).sum().item()

                accuracy = 100 * correct / total
                val_loss = running_val_loss / len(self.val_loader)
                
            print(f'Epoch {epoch + 1}/{self.epochs}, Test_Loss: {train_loss:.4f}, Val_Loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%')
            wandb.log({"acc": accuracy, "loss": val_loss})
            
            if self.best_val_loss > val_loss:
                self.best_val_loss = val_loss
                self.best_model = self.model.state_dict()
                torch.save(self.model.state_dict(), '/data/ephemeral/home/Dongook/model_dw/ens4_rop.pt')

            # 조기 종료
            if self.early_stop and self.patience_limit is not None:
                if val_loss > self.best_val_loss:
                    patience_check += 1
                    if patience_check >= self.patience_limit:
                        break
                else: 
                    self.best_val_loss = val_loss
                    patience_check = 0
                    torch.save(self.model.state_dict(), '/data/ephemeral/home/Dongook/model_dw/ens4_rop.pt')

            self.scheduler.step(val_loss)

In [6]:
transform = A.Compose([
        A.Resize(224,224),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])

In [7]:
def fill_white(img):
    img_np = np.array(img)

    # 각 채널에 대해 Canny 엣지 검출 적용
    edges_r = cv2.Canny(img_np[:, :, 0], 50, 150)
    edges_g = cv2.Canny(img_np[:, :, 1], 50, 150)
    edges_b = cv2.Canny(img_np[:, :, 2], 50, 150)

    # 세 채널의 엣지를 결합하여 하나의 이미지로 생성
    edges_combined = np.maximum(np.maximum(edges_r, edges_g), edges_b)

    # 모폴로지 연산을 위한 커널 생성
    kernel = np.ones((3, 3), np.uint8)  
    # 엣지 이미지를 닫기 위한 모폴로지 연산 적용
    closed_edges = cv2.morphologyEx(edges_combined, cv2.MORPH_CLOSE, kernel)

    filled_image_rgb = cv2.cvtColor(closed_edges, cv2.COLOR_GRAY2RGB)

    inverted_edges = cv2.bitwise_not(closed_edges)

    # 흰색 배경 이미지 생성
    filled_image_rgb = np.full_like(img_np, 255)  
    # 배경이 흰색인 상태에서 검은색 선을 포함한 inverted_edges를 복사
    filled_image_rgb[inverted_edges == 0] = [0, 0, 0] 

    transform = A.Compose([
        A.Affine(scale=(0.5, 1.5), p=0.5),
        A.CoarseDropout(max_holes=4, max_height=30, max_width=30, fill_value=255, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.LongestMaxSize(256),
        A.PadIfNeeded(256, 256, border_mode=cv2.BORDER_CONSTANT, value=(255, 255, 255)),
        A.RandomCrop(224, 224,p=1.0),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])

    transformed = transform(image=filled_image_rgb)
    
    return transformed['image']


In [8]:
class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, is_aug = False):
        self.image_paths = image_paths
        self.labels = labels
        self.is_aug = is_aug

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(os.path.join(train_base_dir, image_path)).convert('RGB')
        if not self.is_aug:
            image = np.array(image)
            image = transform(image=image)['image']
        else:
            image = fill_white(image)
        
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

In [9]:
x_train = []
y_train = []
for i in range(len(train_data)):
    x_train.append(os.path.join(train_base_dir, train_data['image_path'].iloc[i]))
    y_train.append(train_data['target'].iloc[i])

In [10]:
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size = 0.2, stratify= y_train, random_state = 42)

print(len(train_data))
print(len(x_train))
print(len(x_val))

15021
12016
3005


In [11]:
train_dataset = CustomDataset(x_train, y_train, is_aug = False)
aug_dataset = CustomDataset(x_train, y_train, is_aug = True)
val_dataset = CustomDataset(x_val, y_val, is_aug = False)

print(len(train_dataset))
print(len(aug_dataset))
print(len(val_dataset))

dataset = ConcatDataset([train_dataset, aug_dataset])
print(len(dataset))

12016
12016
3005
24032


In [12]:
train_loader = DataLoader(dataset, 32, num_workers=4, shuffle=True)
val_loader = DataLoader(val_dataset, 32, num_workers=2, shuffle=False)

In [6]:
from transformers import CLIPModel

clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = clip.vision_model

class Clip(nn.Module):
    def __init__(self, image_encoder):
        super(Clip, self).__init__()  
        self.clip = image_encoder
        self.mlp = nn.Sequential(
            nn.Linear(1024, 256), 
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 500)
        )

    def forward(self, images):
        with torch.no_grad():
            image_features = self.clip(images).last_hidden_state
        
        y = self.mlp(image_features[:, 0, :])
        return y
    
clip_model = Clip(image_encoder)
clip_model.load_state_dict(torch.load('/data/ephemeral/home/Dongook/model_dw/Clip_rop.pt'))

for param in clip_model.parameters():
    param.requires_grad = False

<All keys matched successfully>

In [7]:
cnn_model = timm.create_model('tf_efficientnet_b3', pretrained = False, num_classes = 500)
cnn_model.load_state_dict(torch.load('/data/ephemeral/home/Dongook/model_dw/effib3_83.79.pt'))

for param in cnn_model.parameters():
    param.requires_grad = False

<All keys matched successfully>

In [8]:
swin_model = timm.create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=500)
swin_model.load_state_dict(torch.load('/data/ephemeral/home/Dongook/model_dw/swin_feed.pt'))

for param in swin_model.parameters():
    param.requires_grad = False

<All keys matched successfully>

In [9]:
res_model = timm.create_model('resnet50d', pretrained = False, num_classes = 500)
res_model.load_state_dict(torch.load('/data/ephemeral/home/Dongook/model_dw/res50d_e120.pt'))

for param in res_model.parameters():
    param.requires_grad = False

<All keys matched successfully>

In [4]:
class AdaptiveEnsemble(nn.Module):
    def __init__(self, clip_model, cnn_model, swin_model, res_model):
        super().__init__()
        self.clip = clip_model
        self.cnn = cnn_model
        self.swin = swin_model
        self.res = res_model

        self.clip_weight = nn.Parameter(torch.tensor(0.3))
        self.cnn_weight = nn.Parameter(torch.tensor(0.2))
        self.swin_weight = nn.Parameter(torch.tensor(0.3))
        self.res_weight = nn.Parameter(torch.tensor(0.2))

    def forward(self, x):
        clip_y = self.clip(x)
        cnn_y = self.cnn(x)
        swin_y = self.swin(x)
        res_y = self.res(x)
        ensemble_y = self.clip_weight * clip_y + self.cnn_weight * cnn_y + self.swin_weight * swin_y + self.res_weight * res_y
        
        return ensemble_y

In [18]:
ensemble_model = AdaptiveEnsemble(clip_model, cnn_model, swin_model, res_model)

epochs = 5
criterion = nn.CrossEntropyLoss(label_smoothing = 0.05)
optimizer = torch.optim.AdamW([ensemble_model.weights], lr=5e-4, weight_decay=1e-4)

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
trainer = Train(ensemble_model, device = device, train_loader = train_loader, val_loader = val_loader, epochs = epochs,
                    optimizer = optimizer, criterion = criterion, scheduler = scheduler, early_stop= True, patience_limit=3)
trainer.train()
wandb.finish()

Epoch 1/5: 100%|██████████| 751/751 [12:27<00:00,  1.00it/s]


Epoch 1/5, Val_Loss: 0.8580, Val_Loss: 1.0075, Accuracy: 89.62%


Epoch 2/5: 100%|██████████| 751/751 [12:26<00:00,  1.01it/s]


Epoch 2/5, Val_Loss: 0.7794, Val_Loss: 1.0018, Accuracy: 89.18%


Epoch 3/5: 100%|██████████| 751/751 [12:26<00:00,  1.01it/s]


Epoch 3/5, Val_Loss: 0.7310, Val_Loss: 1.0085, Accuracy: 88.89%


Epoch 4/5: 100%|██████████| 751/751 [12:27<00:00,  1.01it/s]


Epoch 4/5, Val_Loss: 0.7068, Val_Loss: 1.0161, Accuracy: 88.35%


Epoch 5/5: 100%|██████████| 751/751 [12:26<00:00,  1.01it/s]


Epoch 5/5, Val_Loss: 0.6953, Val_Loss: 1.0221, Accuracy: 87.92%


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
acc,█▆▅▃▁
loss,▃▁▃▆█

0,1
acc,87.92013
loss,1.02206


In [19]:
class TestDataset(Dataset):
    def __init__(self, image_paths):
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(os.path.join(test_base_dir, image_path)).convert('RGB')
        image = np.array(image)
        image = transform(image=image)['image']
        return image
test_dataset = TestDataset(list(test_data['image_path']))

In [20]:
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
print(len(test_loader))

313


In [24]:
basic_ens_model = AdaptiveEnsemble(clip_model, cnn_model, swin_model, res_model)
basic_ens_model.load_state_dict(torch.load('/data/ephemeral/home/Dongook/model_dw/ens4_rop.pt'))

basic_ens_model.to(device)
basic_ens_model.eval()
    
predictions = []
with torch.no_grad():  
    for images in tqdm(test_loader):
        images = images.to(device)
            
        logits = basic_ens_model(images)
        logits = F.softmax(logits, dim=1)
        preds = logits.argmax(dim=1)
            
        predictions.extend(preds.cpu().detach().numpy())

  0%|          | 0/313 [00:00<?, ?it/s]

100%|██████████| 313/313 [03:56<00:00,  1.32it/s]


In [25]:
test_data8 = pd.read_csv(testdata_info_file)

test_data8['target'] = predictions
test_data8 = test_data8.reset_index().rename(columns={"index": "ID"})
test_data8

Unnamed: 0,ID,image_path,target
0,0,0.JPEG,328
1,1,1.JPEG,414
2,2,2.JPEG,493
3,3,3.JPEG,17
4,4,4.JPEG,388
...,...,...,...
10009,10009,10009.JPEG,235
10010,10010,10010.JPEG,191
10011,10011,10011.JPEG,466
10012,10012,10012.JPEG,258


In [26]:
test_data8.to_csv("/data/ephemeral/home/Dongook/output/output_basic_output_ens4.csv", index=False)
test_data8

Unnamed: 0,ID,image_path,target
0,0,0.JPEG,328
1,1,1.JPEG,414
2,2,2.JPEG,493
3,3,3.JPEG,17
4,4,4.JPEG,388
...,...,...,...
10009,10009,10009.JPEG,235
10010,10010,10010.JPEG,191
10011,10011,10011.JPEG,466
10012,10012,10012.JPEG,258
