In [1]:
import torch
import torch.nn as nn

In [2]:
class R2Plus1D_Block(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, stride=1):
        super(R2Plus1D_Block, self).__init__()
        if mid_channels is None:
            mid_channels = out_channels

        self.spatial_conv = nn.Conv3d(
            in_channels,
            mid_channels,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, 1, 1),
            bias=False
        )

        self.bn1 = nn.BatchNorm3d(mid_channels)

        self.temporal_conv = nn.Conv3d(
            mid_channels,
            out_channels,
            kernel_size=(3, 1, 1),
            stride=(stride, 1, 1),
            padding=(1, 0, 0),
            bias=False
        )

        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.spatial_conv(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.temporal_conv(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x


In [3]:
class R2Plus1DNet(nn.Module):
    def __init__(self, num_classes=400):
        super(R2Plus1DNet, self).__init__()
        self.layer1 = R2Plus1D_Block(3, 64, stride=1)
        self.layer2 = R2Plus1D_Block(64, 128, stride=2)
        self.layer3 = R2Plus1D_Block(128, 128, stride=2)
        self.layer4 = R2Plus1D_Block(128, 64, stride=2)

        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [4]:
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {DEVICE} device")
model = R2Plus1DNet(num_classes=1).to(DEVICE)

Using cuda device


In [5]:
from dataset_for_fact_CNN import create_dataloaders

data_dir = '../data/Training/01.원천데이터/이미지'
train_dataloader = create_dataloaders(data_dir, test_ratio = 0, batch_size=8, image_size = 160, workers=16)
data_dir = '../data/Validation/01.원천데이터/이미지'
val_dataloader = create_dataloaders(data_dir, test_ratio = 0, batch_size=8, image_size = 160, workers = 16)

In [6]:
import torch.optim as optim

learning_rate = 0.0001
EPOCHS = 10
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCEWithLogitsLoss()

In [7]:
import torch
from tqdm import tqdm

def val(model, val_loader, criterion, device='cpu'):
    model.eval()
    val_loss = 0.0
    correct_preds = 0
    total_preds = 0

    with torch.no_grad(): 
        for img, labels in tqdm(val_loader):
            img = img.to(device)
            labels = labels.to(device, dtype=torch.float32)

            output = model(img)
            loss = criterion(output.view(-1), labels)
            val_loss += loss.item()

            preds = (output.view(-1) > 0.5).float()
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = correct_preds / total_preds * 100

    print(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    
    return avg_val_loss, val_accuracy

In [8]:
def train(model, train_loader, val_loader, optimizer, criterion, device='cpu', epoches=10):
    model.train()
    best_val_acc = 0
    for epoch in range(epoches):
        epoch_loss = 0.0
        correct_preds = 0
        total_preds = 0
        
        for img, labels in tqdm(train_loader):
            img = img.to(device)
            labels = labels.to(device, dtype=torch.float32) 
            optimizer.zero_grad()
            
            output = model(img)
            
            loss = criterion(output.view(-1), labels)
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()

            preds = (output.view(-1) > 0.5).float()
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)

        avg_loss = epoch_loss / len(train_loader)
        accuracy = correct_preds / total_preds * 100

        print(f"Epoch [{epoch+1}/{epoches}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

        avg_val_loss, val_acc = val(model, val_loader, criterion, device)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model_save_path = 'model.pth'
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': EPOCHS,
                'loss': avg_loss
            }, model_save_path)
            print(f"Model and parameters saved to {model_save_path}")
            
    return avg_val_loss, val_acc
        

In [9]:
avg_loss, acc = train(model, train_dataloader, val_dataloader, optimizer, criterion, device=DEVICE, epoches=EPOCHS)

100%|██████████| 2266/2266 [21:51<00:00,  1.73it/s]


Epoch [1/10], Loss: 0.4385, Accuracy: 82.44%


100%|██████████| 284/284 [02:05<00:00,  2.26it/s]


Validation Loss: 0.2773, Validation Accuracy: 94.67%
Model and parameters saved to model.pth


100%|██████████| 2266/2266 [21:20<00:00,  1.77it/s]


Epoch [2/10], Loss: 0.2279, Accuracy: 90.95%


100%|██████████| 284/284 [02:06<00:00,  2.24it/s]


Validation Loss: 0.1883, Validation Accuracy: 93.09%


100%|██████████| 2266/2266 [21:22<00:00,  1.77it/s]


Epoch [3/10], Loss: 0.1422, Accuracy: 94.69%


100%|██████████| 284/284 [02:05<00:00,  2.26it/s]


Validation Loss: 0.1603, Validation Accuracy: 94.01%


100%|██████████| 2266/2266 [21:22<00:00,  1.77it/s]


Epoch [4/10], Loss: 0.1001, Accuracy: 96.30%


100%|██████████| 284/284 [02:01<00:00,  2.33it/s]


Validation Loss: 0.1553, Validation Accuracy: 96.96%
Model and parameters saved to model.pth


100%|██████████| 2266/2266 [21:23<00:00,  1.77it/s]


Epoch [5/10], Loss: 0.0754, Accuracy: 97.49%


100%|██████████| 284/284 [02:07<00:00,  2.23it/s]


Validation Loss: 0.1168, Validation Accuracy: 96.96%


100%|██████████| 2266/2266 [21:21<00:00,  1.77it/s]


Epoch [6/10], Loss: 0.0717, Accuracy: 97.62%


100%|██████████| 284/284 [02:03<00:00,  2.30it/s]


Validation Loss: 0.0897, Validation Accuracy: 97.36%
Model and parameters saved to model.pth


100%|██████████| 2266/2266 [21:23<00:00,  1.77it/s]


Epoch [7/10], Loss: 0.0549, Accuracy: 98.14%


100%|██████████| 284/284 [02:04<00:00,  2.28it/s]


Validation Loss: 0.1182, Validation Accuracy: 95.38%


100%|██████████| 2266/2266 [21:22<00:00,  1.77it/s]


Epoch [8/10], Loss: 0.0517, Accuracy: 98.41%


100%|██████████| 284/284 [02:02<00:00,  2.31it/s]


Validation Loss: 0.1279, Validation Accuracy: 97.05%


100%|██████████| 2266/2266 [21:23<00:00,  1.77it/s]


Epoch [9/10], Loss: 0.0453, Accuracy: 98.55%


100%|██████████| 284/284 [02:04<00:00,  2.29it/s]


Validation Loss: 0.1000, Validation Accuracy: 98.28%
Model and parameters saved to model.pth


100%|██████████| 2266/2266 [21:22<00:00,  1.77it/s]


Epoch [10/10], Loss: 0.0391, Accuracy: 98.80%


100%|██████████| 284/284 [02:06<00:00,  2.25it/s]

Validation Loss: 0.1501, Validation Accuracy: 97.76%



