<a href="https://colab.research.google.com/github/dlsqja/Machine-Learning/blob/main/Untitled2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [199]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from PIL import Image
import json
import os

In [202]:
class ImageJsonDataset(Dataset):
    def __init__(self, image_folder, json_folder, transform=None):
        self.image_folder = image_folder
        self.json_folder = json_folder
        self.transform = transform
        self.json_files = self.load_json_files()
        self.image_files = self.load_image_files()

    def load_image_files(self):
        image_files = []
        for root, _, files in os.walk(self.image_folder):
            for file in sorted(files):
                   image_files.append(os.path.join(root, file))
        return image_files

    def load_json_files(self):
        json_files = []
        for root, _, files in os.walk(self.json_folder):
            for file in sorted(files):
                   json_files.append(os.path.join(root, file))
        return json_files

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):

        
        image_path = self.image_files[index]
        json_path = self.json_files[index]
        image = Image.open(image_path)

        image = self.transform(image)
        with open(json_path, 'r') as f:
            json_data = json.load(f)

        disease = json_data["annotations"]["disease"]
        print("index:",index)
        print(" - Disease:", disease)
        return image, json_data

# 이미지 폴더 경로와 JSON 폴더 경로
image_folder_path = '/content/drive/MyDrive/New_sample/원천데이터/01.가지/0.정상'
json_folder_path = '/content/drive/MyDrive/New_sample/라벨링데이터/01.가지/0.정상'

# 전처리 및 변환 설정
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 이미지 크기 조정
    transforms.ToTensor(),  # 이미지를 텐서로 변환
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 이미지 정규화
])

# 이미지와 JSON을 매칭시키는 데이터셋 생성
dataset = ImageJsonDataset(image_folder_path, json_folder_path, transform=transform)

# 데이터로더 생성
data_loader = DataLoader(dataset, batch_size=32, shuffle=False)


In [203]:
learning_rate=0.001
num_classes = 3 
num_epochs=10

#디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#ResNet-50 모델 로드
model = torchvision.models.resnet50(pretrained=True)

#새로운 분류 레이어 추가
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

#모델을 디바이스로 이동
model = model.to(device)

#손실 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

#모델 학습
for epoch in range(num_epochs):
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print the loss for monitoring
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")




index: 0
 - Disease: 0
index: 1
 - Disease: 0
index: 2
 - Disease: 0
index: 3
 - Disease: 0
index: 4
 - Disease: 0
index: 5
 - Disease: 0
index: 6
 - Disease: 0
index: 7
 - Disease: 0
index: 8
 - Disease: 0
index: 9
 - Disease: 0
index: 10
 - Disease: 0
index: 11
 - Disease: 0
index: 12
 - Disease: 0
index: 13
 - Disease: 0
index: 14
 - Disease: 0
index: 15
 - Disease: 0
index: 16
 - Disease: 0
index: 17
 - Disease: 0
index: 18
 - Disease: 0
index: 19
 - Disease: 0
index: 20
 - Disease: 0
index: 21
 - Disease: 0
index: 22
 - Disease: 0
index: 23
 - Disease: 0
index: 24
 - Disease: 0
index: 25
 - Disease: 0
index: 26
 - Disease: 0
index: 27
 - Disease: 0
index: 28
 - Disease: 0
index: 29
 - Disease: 0
index: 30
 - Disease: 0
index: 31
 - Disease: 0


TypeError: ignored