In [1]:
# 클래스 ID 및 레이블 매핑 정의

id2label = {
            0: 'background',
            1: 'common_road',
            2: 'common_tree',
            3: 'field_corps',
            4: 'field_furrow',
            5: 'field_levee',
            6: 'orchard_road',
            7: 'orchard_tree',
            8: 'paddy_after_driving',
            9: 'paddy_before_driving',
            10: 'paddy_edge',
            11: 'paddy_rice',
            12: 'paddy_water' 
        }


# 반대 매핑 생성
label2id = {v: k for k, v in id2label.items()}

In [2]:
import json
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
import torch
import torchvision.transforms as T
from skimage.draw import polygon2mask
import cv2

class CustomDataset(Dataset):
    def __init__(self, img_dir, ann_dir, transforms=None):
        self.img_dir = img_dir
        self.ann_dir = ann_dir
        self.transforms = transforms
        
        
        
        self.label2id = {v: k for k, v in id2label.items()}  # 레이블을 ID로 매핑

        self.label_list = os.listdir(self.ann_dir)
        self.img_paths = []
        self.ann_paths = []
        
        # 미리 경로 계산
        for label in self.label_list:
            with open(os.path.join(self.ann_dir, label), 'r', encoding='utf-8') as f:
                img_info = json.load(f)
            img_path = os.path.join(self.img_dir, img_info['name'])
            self.img_paths.append(img_path)
            self.ann_paths.append(os.path.join(self.ann_dir, label))

    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        ann_path = self.ann_paths[idx]
        
        with open(ann_path, 'r', encoding='utf-8') as f:
            img_info = json.load(f)
        
        image = Image.open(img_path).convert('RGB')
        mask = np.full((image.height, image.width), 0, dtype=np.int32)  # 초기값을 0으로 설정
        
        for obj in img_info['objects']:
            class_id = self.get_class_index(obj['label'])
            for pos in obj['position']:
                coords = [(y, x) for x, y in zip(pos[::2], pos[1::2])] #pos[::2]는 짝수 인덱스(모든 x 좌표), pos[1::2]는 홀수 인덱스(모든 y 좌표)
                # print(f"Coords for {obj['label']} with class ID {class_id}: {coords}")  # 디버깅용 좌표 출력

                 # 좌표가 이미지 경계를 벗어나는지 확인
                out_of_bounds_coords = [(x, y) for y, x in coords if x < 0 or x >= image.width or y < 0 or y >= image.height]
                if out_of_bounds_coords:
                    #print(f"Warning: Some coordinates for {obj['label']} are out of image bounds: {out_of_bounds_coords}")
                    coords = [(max(0, min(image.height - 0.1, y)), max(0, min(image.width - 0.1, x))) for y, x in coords] # 이미지 벗어나는 좌표 클리핑

            

                poly_mask = polygon2mask((image.height, image.width), coords)

                # 디버깅용으로 poly_mask가 True인 위치 출력
                # true_indices = np.where(poly_mask)
                # true_positions = list(zip(true_indices[0], true_indices[1]))
                # print(f"True positions in poly_mask: {true_positions[:10]}")  # 처음 10개의 위치만 출력

                mask[poly_mask] = class_id

                # # 개별 마스크 시각화(디버깅)
                # plt.figure(figsize=(6, 6))
                # plt.imshow(mask, cmap='gray')
                # plt.title(f"{obj['label']} with class ID {class_id}")
                # plt.show()
                
        if self.transforms:
            augmented = self.transforms(image=np.array(image), mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        
        

        return image, mask

        # if self.transforms:
        #     image = self.transforms(image)
        
        # mask = torch.tensor(mask, dtype=torch.long)

        # return image, mask

    def get_class_index(self, label):
        return self.label2id.get(label, 0)  # 기본값 0 (배경)

In [3]:
from torch.utils.data import DataLoader
import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2 

# ### Data augmentation을 고려할 필요가 있음

# train_transforms = T.Compose([
#     T.ToTensor(),
#     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# val_transforms = T.Compose([
#     T.ToTensor(),
#     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])


train_transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.ColorJitter(p=0.5),
    A.GaussianBlur(blur_limit=3, p=0.1),
    A.GaussNoise(p=0.2),
    A.OneOf([
        A.MotionBlur(p=0.2),
        A.MedianBlur(blur_limit=3, p=0.1),
    ], p=1.0),
    A.OneOf([
        A.RandomRain(p=0.2),
        A.RandomFog(p=0.2),
        A.RandomShadow(p=0.2),
    ], p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

val_transforms = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])


# DataLoader 설정
train_dataset = CustomDataset("C:/Users/USER/Desktop/resized_train_images","C:/Users/USER/Desktop/resized_train_annotations", transforms=train_transforms)
val_dataset = CustomDataset("C:/Users/USER/Desktop/resized_valid_images", "C:/Users/USER/Desktop/resized_valid_annotations", transforms=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=13, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=13, shuffle=False, num_workers=0)


  self.__pydantic_validator__.validate_python(data, self_instance=self)


In [4]:
# start_epoch = 6  # 시작 에폭 설정
# num_epochs = 10  # 총 학습 에폭 설정
import wandb
wandb.init(project="knowledge_distillation", id='b3_b2')

# # wandb 초기화
# wandb.init(project="uncategorized_project", resume="allow", id="lemon-cherry-45")
# wandb.config.update({
#     "epochs": num_epochs,
#     "batch_size": 16,
#     "start_epoch": start_epoch,
#     "learning_rate": 0.00006,
# })

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mchldyddn98[0m ([33mchoiyw[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
import evaluate
import os
from transformers import SegformerForSemanticSegmentation

# 저장할 폴더 지정 및 생성
model_save_path = "saved_models"
model_filename = "best_model.pth"
if not os.path.exists(model_save_path):
    os.makedirs(model_save_path)
full_model_path = os.path.join(model_save_path, model_filename)

# 설정 변수
gradient_accumulation_steps = 4
best_mean_iou = 0.0

# Mean IoU 계산 함수 정의
def compute_mean_iou(predictions, references, num_labels, ignore_index):
    classes_in_labels = torch.unique(references)
    if ignore_index is not None:
        classes_in_labels = classes_in_labels[classes_in_labels != ignore_index]
    
    iou_list = []
    for cls in range(num_labels):
        if cls in classes_in_labels:
            intersection = torch.sum((predictions == cls) & (references == cls)).item()
            union = torch.sum((predictions == cls) | (references == cls)).item()
            iou = intersection / union if union != 0 else 0
            iou_list.append(iou)
    
    mean_iou = sum(iou_list) / len(iou_list) if iou_list else 0
    return mean_iou

# 검증 함수
def validate_amp(model, val_loader, device, num_labels=13, ignore_index=255):
    model.eval()  # 모델을 평가 모드로 전환
    total_loss = 0
    total_correct = 0
    total_pixels = 0
    iou_metric = evaluate.load("mean_iou", trust_remote_code=True)  # 각 검증 루프마다 메트릭을 새로 초기화
    
    with torch.no_grad():  
        pbar = tqdm(total=len(val_loader), desc="Validating", unit="batch")  # tqdm을 사용하여 검증 진행 상황 표시
        for batch in val_loader:
            pixel_values = batch[0].to(device).float()  # 입력 데이터를 장치로 이동
            labels = batch[1].to(device).long()  # 라벨 데이터를 장치로 이동
            
            with autocast():  # 자동 혼합 정밀도(AMP)를 사용하여 연산
                outputs = model(pixel_values=pixel_values)  
                logits = outputs.logits  # 로짓 값
                upsampled_logits = F.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)  # 로짓을 라벨 크기로 보간
                loss = F.cross_entropy(upsampled_logits, labels, ignore_index=ignore_index)  # 교차 엔트로피 손실 계산
                
            total_loss += loss.item()  # 총 손실 계산
            predicted = upsampled_logits.argmax(dim=1)  # 예측값 계산

            # Pixel accuracy 계산
            correct = (predicted == labels).sum().item()
            total = (labels != ignore_index).sum().item()
            total_correct += correct
            total_pixels += total

            iou_metric.add_batch(predictions=predicted, references=labels)  # 메트릭에 배치 추가
            pbar.update(1) 

        pbar.close()  

    avg_loss = total_loss / len(val_loader)  # 평균 손실 계산
    mean_iou = compute_mean_iou(predicted, labels, num_labels, ignore_index)  # Mean IoU 계산
    pixel_accuracy = total_correct / total_pixels  # Pixel Accuracy 계산
    return avg_loss, mean_iou, pixel_accuracy  # 평균 손실과 Mean IoU 반환

# Knowledge Distillation 손실 함수 정의
def distillation_loss(student_outputs, teacher_outputs, labels, T=2.0, alpha=0.5):
    soft_teacher_outputs = F.softmax(teacher_outputs / T, dim=1)
    soft_student_outputs = F.log_softmax(student_outputs / T, dim=1)
    loss_soft = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction='batchmean') * (T * T)
    loss_hard = F.cross_entropy(student_outputs, labels)
    return loss_soft * alpha + loss_hard * (1.0 - alpha)

# 학습 및 검증 함수 정의
def train_student_model(teacher_model, student_model, train_loader, val_loader, optimizer, num_epochs=10):
    best_mean_iou = 0.0
    gradient_accumulation_steps = 4
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(num_epochs):
        student_model.train()
        train_loss = 0
        pbar = tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

        for idx, batch in enumerate(train_loader):
            pixel_values = batch[0].to(device).float()
            labels = batch[1].to(device).long()

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    teacher_outputs = teacher_model(pixel_values=pixel_values).logits
                    teacher_outputs_upsampled = F.interpolate(teacher_outputs, size=labels.shape[-2:], mode="bilinear", align_corners=False)

                student_outputs = student_model(pixel_values=pixel_values).logits

                # 학생 모델의 출력을 라벨 크기로 맞추기
                student_outputs_upsampled = F.interpolate(student_outputs, size=labels.shape[-2:], mode="bilinear", align_corners=False)

                loss = distillation_loss(student_outputs_upsampled, teacher_outputs_upsampled, labels) / gradient_accumulation_steps

            scaler.scale(loss).backward()

            if (idx + 1) % gradient_accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += loss.item() * gradient_accumulation_steps
            pbar.update(1)
            pbar.set_postfix(batch_loss=f"{loss.item():.4f}")

        pbar.close()

        val_loss, mean_iou, pixel_accuracy = validate_amp(student_model, val_loader, device)
        wandb.log({"Validation Loss": val_loss, "Mean IOU": mean_iou, "Pixel Accuracy": pixel_accuracy})

        print(f"Validation Loss: {val_loss:.4f}, Mean IOU: {mean_iou}, Pixel Accuracy: {pixel_accuracy}")

        # 모델 저장
        if mean_iou > best_mean_iou:
            best_mean_iou = mean_iou
            torch.save(student_model.state_dict(), full_model_path)
            print(f"Saved best model to {full_model_path}")

# 모델 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model_name = "nvidia/segformer-b3-finetuned-cityscapes-1024-1024"
student_model_name = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"  # 경량화된 모델

teacher_model = SegformerForSemanticSegmentation.from_pretrained(teacher_model_name)
teacher_model.config.num_labels = 13
teacher_model.decode_head.classifier = torch.nn.Conv2d(768, teacher_model.config.num_labels, kernel_size=1)
teacher_model.load_state_dict(torch.load("C:/Users/USER/Desktop/yong/deep_learning_bootcamp/image_project2/saved_models/b3_10epoch_512288.pth", map_location=device))
teacher_model.eval()
teacher_model.to(device)

student_model = SegformerForSemanticSegmentation.from_pretrained(student_model_name)
student_model.config.num_labels = 13
student_model.decode_head.classifier = torch.nn.Conv2d(768, student_model.config.num_labels, kernel_size=1)
student_model.to(device)



# 학습 및 검증
optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.00006)
train_student_model(teacher_model, student_model, train_loader, val_loader, optimizer, num_epochs=10)


  from .autonotebook import tqdm as notebook_tqdm
INFO:datasets:PyTorch version 2.3.1 available.
Epoch 1/10:   8%|▊         | 69/831 [01:35<17:53,  1.41s/batch, batch_loss=161092.6094]