In [1]:
!pip install pynvml

Collecting pynvml
  Downloading pynvml-12.0.0-py3-none-any.whl.metadata (5.4 kB)
Collecting nvidia-ml-py<13.0.0a0,>=12.0.0 (from pynvml)
  Downloading nvidia_ml_py-12.560.30-py3-none-any.whl.metadata (8.6 kB)
Downloading pynvml-12.0.0-py3-none-any.whl (26 kB)
Downloading nvidia_ml_py-12.560.30-py3-none-any.whl (40 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nvidia-ml-py, pynvml
Successfully installed nvidia-ml-py-12.560.30 pynvml-12.0.0


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
import torch.utils.checkpoint as checkpoint

import matplotlib.pyplot as plt
import numpy as np

import copy
from collections import namedtuple
import time
import os
import random
import re

import cv2
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image

from tqdm import tqdm
from pynvml import *
import pandas as pd

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
def print_gpu_utilization():
    if torch.cuda.is_available():
        device = torch.cuda.current_device()  # 현재 GPU 디바이스 정보
        allocated_memory = torch.cuda.memory_allocated(device) / 1024**3  # 메모리 사용량 (GB)
        reserved_memory = torch.cuda.memory_reserved(device) / 1024**3  # 예약된 메모리 (GB)
        print(f"Allocated Memory: {allocated_memory:.2f} GB")
        print(f"Reserved Memory: {reserved_memory:.2f} GB")
    else:
        print("No GPU available.")

In [4]:
def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()

In [5]:
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 32

In [6]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

In [7]:
# CIFAR-10
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 78.3MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [8]:
VALID_RATIO = 0.7
n_train_examples = int(len(trainset) * VALID_RATIO)
n_valid_examples = len(trainset) - n_train_examples

train_data, valid_data = data.random_split(trainset, [n_train_examples, n_valid_examples])

In [9]:
valid_data = copy.deepcopy(valid_data)
valid_data.dataset.transform = test_transforms

In [10]:
len(train_data), len(valid_data), len(testset)

(35000, 15000, 10000)

In [11]:
sample_fraction = 0.2

# 무작위 인덱스 생성
train_indices = torch.randperm(len(trainset))[:int(len(trainset) * sample_fraction)]
valid_indices = torch.randperm(len(valid_data))[:int(len(valid_data) * sample_fraction)]
test_indices = torch.randperm(len(testset))[:int(len(testset) * sample_fraction)]

# 서브셋 생성
train_subset = Subset(trainset, train_indices)
valid_subset = Subset(valid_data, valid_indices)
test_subset = Subset(testset, test_indices)

In [12]:
len(train_subset), len(valid_subset), len(test_subset)

(10000, 3000, 2000)

In [13]:
train_iterator = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
valid_iterator = DataLoader(valid_subset, batch_size=batch_size, shuffle=False)
test_iterator = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

In [14]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample = False):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        if downsample:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm2d(out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None
        self.downsample = downsample

    def forward(self, x):
        i = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample is not None:
            i = self.downsample(i)

        x += i
        x = self.relu(x)

        return x

In [15]:
class ResNet(nn.Module):
    def __init__(self, config, output_dim, zero_init_residual = False):
        super().__init__()

        block, n_blocks, channels = config
        self.in_channels = channels[0]
        assert len(n_blocks) == len(channels) == 4

        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.in_channels, output_dim)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                #elif isinstance(m, Bottleneck):
                    #nn.init.constant_(m.bn3.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        layers = []
        if self.in_channels != block.expansion * channels:
            downsample = True
        else:
            downsample = False
        layers.append(block(self.in_channels, channels, stride, downsample))
        for i in range(1, n_blocks):
            layers.append(block(block.expansion * channels, channels))

        self.in_channels = block.expansion * channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = checkpoint.checkpoint(self.layer1, x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)
        x = self.fc(h)
        return x, h

In [16]:
ResNetConfig = namedtuple('ResNetConfig', ['block', 'n_blocks', 'channels'])

In [17]:
resnet18_config = ResNetConfig(block = BasicBlock, n_blocks = [2, 2, 2, 2], channels = [64, 128, 256, 512])

In [18]:
model = ResNet(resnet18_config, 10)

In [19]:
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kerne

In [20]:
optimizer = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

pretrained_model = model.to(device)
criterion = criterion.to(device)

In [21]:
def calculate_accuracy(y_pred, y):
    top_pred = y_pred.argmax(1, keepdim=True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    acc = correct.float() / y.shape[0]
    return acc

In [22]:
# 디버깅용 코드 포함 버전 (LifetimeAwareMemoryPool & MelonTrainer)
# class LifetimeAwareMemoryPool:
#     def __init__(self, memory_budget):
#         print(f"[DEBUG] Initializing MemoryPool with budget: {memory_budget}")
#         self.memory_budget = memory_budget
#         self.allocated_memory = 0
#         self.memory_blocks = []
#         self.tensor_map = {}

#     def allocate(self, tensor_id, size, lifetime):
#         print(f"[DEBUG] Attempting to allocate tensor {tensor_id} with size {size} and lifetime {lifetime}")

#         if tensor_id in self.tensor_map:
#             print(f"[DEBUG] Tensor {tensor_id} already allocated")
#             return self.tensor_map[tensor_id]

#         best_addr = self._find_best_fit(size, lifetime)

#         if best_addr is None:
#             print(f"[DEBUG] Memory fragmented, performing compaction")
#             self._compact()
#             best_addr = self._find_best_fit(size, lifetime)
#             if best_addr is None:
#                 print(f"[DEBUG] Failed to allocate memory for tensor {tensor_id}")
#                 raise MemoryError("Not enough memory")

#         block_index = len(self.memory_blocks)
#         self.memory_blocks.append((best_addr, size, tensor_id, lifetime))
#         self.tensor_map[tensor_id] = block_index
#         self.allocated_memory += size

#         print(f"[DEBUG] Successfully allocated tensor {tensor_id} at address {best_addr}")
#         return best_addr

#     def free(self, tensor_id):
#         print(f"[DEBUG] Attempting to free tensor {tensor_id}")
#         if tensor_id in self.tensor_map:
#             block_index = self.tensor_map[tensor_id]
#             _, size, _, _ = self.memory_blocks[block_index]
#             self.allocated_memory -= size
#             del self.tensor_map[tensor_id]
#             self.memory_blocks[block_index] = None
#             print(f"[DEBUG] Successfully freed tensor {tensor_id}")
#         else:
#             print(f"[DEBUG] Tensor {tensor_id} not found in memory pool")

#     def _find_best_fit(self, size, lifetime):
#         print(f"[DEBUG] Finding best fit for size {size} with lifetime {lifetime}")

#         if self.allocated_memory + size > self.memory_budget:
#             print(f"[DEBUG] Not enough memory in budget")
#             return None

#         # 1. 재사용 가능한 메모리 블록 찾기
#         available_addr = 0
#         for block in self.memory_blocks:
#             if block is None:
#                 continue
#             block_addr, block_size, block_id, block_lifetime = block
#             print(f"[DEBUG] Checking block at {block_addr} with size {block_size} (tensor {block_id})")

#             # 수명이 겹치지 않는 경우 해당 공간 재사용
#             if not self._lifetimes_overlap(lifetime, block_lifetime):
#                 print(f"[DEBUG] Found potential reuse block at {block_addr}")
#                 if available_addr == 0:  # 첫 번째로 찾은 재사용 가능한 블록 사용
#                     print(f"[DEBUG] Reusing memory at address {available_addr}")
#                     return available_addr
#             available_addr = max(available_addr, block_addr + block_size)

#         # 2. 새로운 메모리 공간 할당
#         if self.allocated_memory + size <= self.memory_budget:
#             print(f"[DEBUG] Allocating at new address {available_addr}")
#             return available_addr

#         print(f"[DEBUG] No suitable location found")
#         return None

#     def _calculate_address_at_position(self, pos):
#         """주어진 위치에 맞는 메모리 주소 계산"""
#         if pos == 0:
#             return 0
#         prev_block = self.memory_blocks[pos-1]
#         return prev_block[0] + prev_block[1]

#     def _compact(self):
#         print(f"[DEBUG] Starting memory compaction")
#         valid_blocks = [b for b in self.memory_blocks if b is not None]
#         print(f"[DEBUG] Found {len(valid_blocks)} valid blocks")

#         valid_blocks.sort(key=lambda x: x[3])

#         self.memory_blocks = []
#         self.tensor_map.clear()
#         self.allocated_memory = 0

#         current_addr = 0
#         for _, size, tensor_id, lifetime in valid_blocks:
#             self.memory_blocks.append((current_addr, size, tensor_id, lifetime))
#             self.tensor_map[tensor_id] = len(self.memory_blocks) - 1
#             self.allocated_memory += size
#             current_addr += size
#             print(f"[DEBUG] Reallocated tensor {tensor_id} to address {current_addr-size}")

#     def _lifetimes_overlap(self, lifetime1, lifetime2):
#         print(f"[DEBUG] Checking lifetime overlap: {lifetime1} vs {lifetime2}")
#         start1, end1 = lifetime1
#         start2, end2 = lifetime2
#         overlap = not (end1 <= start2 or end2 <= start1)
#         print(f"[DEBUG] Overlap result: {overlap}")
#         return overlap


# class MelonTrainer:
#     def __init__(self, model, criterion, optimizer, device, memory_budget):
#         print(f"[DEBUG] Initializing MelonTrainer with memory budget: {memory_budget}")
#         self.model = model.to(device)
#         self.criterion = criterion
#         self.optimizer = optimizer
#         self.device = device
#         self.memory_budget = memory_budget
#         self.has_bn = self._check_has_bn()
#         print(f"[DEBUG] Model has BatchNorm layers: {self.has_bn}")
#         self.memory_pool = self._initialize_memory_pool()

#     def _check_has_bn(self):
#         print("[DEBUG] Checking for BatchNorm layers in model")
#         for module in self.model.modules():
#             if isinstance(module, nn.BatchNorm2d):
#                 print("[DEBUG] Found BatchNorm2d layer")
#                 return True
#         print("[DEBUG] No BatchNorm layers found")
#         return False

#     def _initialize_memory_pool(self):
#         print("[DEBUG] Initializing LifetimeAwareMemoryPool")
#         return LifetimeAwareMemoryPool(self.memory_budget)

#     def train(self, train_loader):
#         print("[DEBUG] Starting training")
#         start_time = time.monotonic()
#         self.model.train()
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         print("[DEBUG] Setting up profiler")
#         with profile(
#             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#             profile_memory=True,
#             record_shapes=True
#         ) as prof:
#             for batch_idx, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training")):
#                 print(f"[DEBUG] Processing batch {batch_idx}")
#                 inputs, labels = inputs.to(self.device), labels.to(self.device)

#                 # Allocate memory for inputs and labels using the memory pool
#                 input_tensor_id = f"input_batch_{batch_idx}"
#                 label_tensor_id = f"label_batch_{batch_idx}"

#                 try:
#                     # Allocate memory for the input and label tensors
#                     self.memory_pool.allocate(input_tensor_id, inputs.element_size() * inputs.nelement(), (batch_idx, batch_idx + 1))
#                     self.memory_pool.allocate(label_tensor_id, labels.element_size() * labels.nelement(), (batch_idx, batch_idx + 1))
#                     print(f"[DEBUG] Allocated memory for input and label tensors for batch {batch_idx}")
#                 except MemoryError as e:
#                     print(f"[ERROR] Memory allocation failed for batch {batch_idx}: {e}")
#                     continue

#                 if self.has_bn:
#                     # BatchNorm이 있는 경우 recomputation 사용
#                     print("[DEBUG] Using recomputation strategy")
#                     loss, acc = self._train_step_with_recomputation(inputs, labels)
#                 else:
#                     # BatchNorm이 없는 경우 micro-batch 사용
#                     print("[DEBUG] Using micro-batch strategy")
#                     loss, acc = self._train_step_with_microbatch(inputs, labels)

#                 running_loss += loss
#                 correct += acc[0]
#                 total += acc[1]

#                 # Free the allocated memory after processing the batch
#                 self.memory_pool.free(input_tensor_id)
#                 self.memory_pool.free(label_tensor_id)
#                 print(f"[DEBUG] Freed memory for input and label tensors for batch {batch_idx}")

#                 print(f"[DEBUG] Batch {batch_idx} - Loss: {loss:.4f}, Accuracy: {acc[0]/acc[1]*100:.2f}%")

#         end_time = time.monotonic()
#         epoch_loss = running_loss / len(train_loader)
#         accuracy = 100 * correct / total if total > 0 else 0

#         print(f"[DEBUG] Training completed - Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")
#         return epoch_loss, accuracy, start_time, end_time

#     def _train_step_with_recomputation(self, inputs, labels):
#         print("[DEBUG] Starting recomputation training step")
#         self.optimizer.zero_grad()

#         # Forward pass with checkpoints
#         with torch.no_grad():
#             intermediate_outputs = []
#             x = inputs

#             print("[DEBUG] Processing initial layers")
#             x = self.model.conv1(x)
#             x = self.model.bn1(x)
#             x = self.model.relu(x)
#             x = self.model.maxpool(x)

#             print("[DEBUG] Processing main layers")
#             for layer_name in ['layer1', 'layer2', 'layer3', 'layer4']:
#                 print(f"[DEBUG] Processing {layer_name}")
#                 layer = getattr(self.model, layer_name)
#                 x = layer(x)
#                 if self.has_bn:
#                     intermediate_outputs.append(x.detach())
#                     print(f"[DEBUG] Saved checkpoint for {layer_name}")

#         print("[DEBUG] Starting forward pass")
#         with record_function("forward_pass"):
#             outputs = self.model(inputs)
#             if isinstance(outputs, tuple):
#                 outputs = outputs[0]

#         print("[DEBUG] Computing loss")
#         with record_function("loss_computation"):
#             loss = self.criterion(outputs, labels)

#         print("[DEBUG] Backward pass")
#         with record_function("backward_pass"):
#             loss.backward()

#         print("[DEBUG] Optimizer step")
#         with record_function("optimizer_step"):
#             self.optimizer.step()

#         _, predicted = torch.max(outputs.data, 1)
#         correct = (predicted == labels).sum().item()
#         total = labels.size(0)

#         print(f"[DEBUG] Step completed - Loss: {loss.item():.4f}, Accuracy: {correct/total*100:.2f}%")
#         return loss.item(), (correct, total)

#     def _calculate_micro_batch_size(self, input_size):
#         print(f"[DEBUG] Calculating micro-batch size for input shape {input_size}")
#         tensor_size = input_size[1] * input_size[2] * input_size[3] * 4
#         micro_batch_size = max(1, min(input_size[0], self.memory_budget // tensor_size))
#         print(f"[DEBUG] Calculated micro-batch size: {micro_batch_size}")
#         return micro_batch_size

In [23]:
class LifetimeAwareMemoryPool:
    def __init__(self, memory_budget):
        self.memory_budget = memory_budget
        self.allocated_memory = 0
        self.memory_blocks = []
        self.tensor_map = {}

    def allocate(self, tensor_id, size, lifetime):
        if tensor_id in self.tensor_map:
            return self.tensor_map[tensor_id]

        best_addr = self._find_best_fit(size, lifetime)

        if best_addr is None:
            self._compact()
            best_addr = self._find_best_fit(size, lifetime)
            if best_addr is None:
                raise MemoryError("Not enough memory")

        block_index = len(self.memory_blocks)
        self.memory_blocks.append((best_addr, size, tensor_id, lifetime))
        self.tensor_map[tensor_id] = block_index
        self.allocated_memory += size

        return best_addr

    def free(self, tensor_id):
        if tensor_id in self.tensor_map:
            block_index = self.tensor_map[tensor_id]
            _, size, _, _ = self.memory_blocks[block_index]
            self.allocated_memory -= size
            del self.tensor_map[tensor_id]
            self.memory_blocks[block_index] = None

    def _find_best_fit(self, size, lifetime):
        if self.allocated_memory + size > self.memory_budget:
            return None

        available_addr = 0
        for block in self.memory_blocks:
            if block is None:
                continue
            block_addr, block_size, block_id, block_lifetime = block
            if not self._lifetimes_overlap(lifetime, block_lifetime):
                if available_addr == 0:
                    return available_addr
            available_addr = max(available_addr, block_addr + block_size)

        if self.allocated_memory + size <= self.memory_budget:
            return available_addr

        return None

    def _calculate_address_at_position(self, pos):
        if pos == 0:
            return 0
        prev_block = self.memory_blocks[pos-1]
        return prev_block[0] + prev_block[1]

    def _compact(self):
        valid_blocks = [b for b in self.memory_blocks if b is not None]
        valid_blocks.sort(key=lambda x: x[3])

        self.memory_blocks = []
        self.tensor_map.clear()
        self.allocated_memory = 0

        current_addr = 0
        for _, size, tensor_id, lifetime in valid_blocks:
            self.memory_blocks.append((current_addr, size, tensor_id, lifetime))
            self.tensor_map[tensor_id] = len(self.memory_blocks) - 1
            self.allocated_memory += size
            current_addr += size

    def _lifetimes_overlap(self, lifetime1, lifetime2):
        start1, end1 = lifetime1
        start2, end2 = lifetime2
        return not (end1 <= start2 or end2 <= start1)

In [24]:
class MelonTrainer:
    def __init__(self, model, criterion, optimizer, device, memory_budget):
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.memory_budget = memory_budget
        self.has_bn = self._check_has_bn()
        self.memory_pool = self._initialize_memory_pool()

    def _check_has_bn(self):
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                return True
        return False

    def _initialize_memory_pool(self):
        return LifetimeAwareMemoryPool(self.memory_budget)

    def train(self, train_loader):
        start_time = time.monotonic()
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training")):
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            # Allocate memory for inputs and labels using the memory pool
            input_tensor_id = f"input_batch_{batch_idx}"
            label_tensor_id = f"label_batch_{batch_idx}"

            try:
                self.memory_pool.allocate(input_tensor_id, inputs.element_size() * inputs.nelement(), (batch_idx, batch_idx + 1))
                self.memory_pool.allocate(label_tensor_id, labels.element_size() * labels.nelement(), (batch_idx, batch_idx + 1))
            except MemoryError:
                continue

            if self.has_bn:
                loss, acc = self._train_step_with_recomputation(inputs, labels)
            else:
                loss, acc = self._train_step_with_microbatch(inputs, labels)

            running_loss += loss
            correct += acc[0]
            total += acc[1]

            self.memory_pool.free(input_tensor_id)
            self.memory_pool.free(label_tensor_id)

        end_time = time.monotonic()


        # 훈련 후 평균 손실과 정확도 계산
        epoch_loss = running_loss / len(train_loader)
        accuracy = 100 * correct / total if total > 0 else 0

        return epoch_loss, accuracy, start_time, end_time

    def _train_step_with_recomputation(self, inputs, labels):
        self.optimizer.zero_grad()

        with torch.no_grad():
            intermediate_outputs = []
            x = inputs

            x = self.model.conv1(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)
            x = self.model.maxpool(x)

            for layer_name in ['layer1', 'layer2', 'layer3', 'layer4']:
                layer = getattr(self.model, layer_name)
                x = layer(x)
                if self.has_bn:
                    intermediate_outputs.append(x.detach())

        outputs = self.model(inputs)
        if isinstance(outputs, tuple):
            outputs = outputs[0]

        loss = self.criterion(outputs, labels)

        loss.backward()

        for i, param in enumerate(model.parameters()):
          if i == len(list(model.parameters())) - 1:
            break
          if param.grad is not None:
            grad_values = param.grad.abs().view(-1)
            topk_values, _ = grad_values.topk(10, largest=True)

            threshold = topk_values[-1]

            mask = param.grad.abs() >= threshold

            updated_grad = torch.zeros_like(param.grad)
            updated_grad[mask] = param.grad[mask]

            del param.grad
            torch.cuda.empty_cache()
            param.grad = updated_grad.clone().detach()

        self.optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        correct = (predicted == labels).sum().item()
        total = labels.size(0)

        return loss.item(), (correct, total)

    def _calculate_micro_batch_size(self, input_size):
        tensor_size = input_size[1] * input_size[2] * input_size[3] * 4
        micro_batch_size = max(1, min(input_size[0], self.memory_budget // tensor_size))
        return micro_batch_size

In [25]:
def evaluate(model, data_loader, criterion, device, phase="Validation"):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs[0], labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs[0], 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(data_loader)
    accuracy = 100 * correct / total
    # print(f"{phase} Loss: {epoch_loss:.4f}, {phase} Accuracy: {accuracy:.2f}%")

    return epoch_loss, accuracy

In [26]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [27]:
torch.cuda.empty_cache()

In [28]:
print_gpu_utilization()

Allocated Memory: 0.04 GB
Reserved Memory: 0.06 GB


In [29]:
free_memory, total_memory = torch.cuda.mem_get_info()
print(f"Free memory: {free_memory / 1024**2:.2f} MB")
print(f"Total memory: {total_memory / 1024**2:.2f} MB")

Free memory: 40026.81 MB
Total memory: 40513.81 MB


In [30]:
trainer = MelonTrainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    memory_budget=4096 * 1024**2
)

EPOCHS = 50
best_valid_loss = float('inf')
total_time = 0
for epoch in range(EPOCHS):
    train_loss, train_acc, start_time, end_time = trainer.train(train_iterator)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    epoch_duration = epoch_mins * 60 + epoch_secs
    total_time += epoch_duration

    print(f'Epoch: {epoch+1:02} | Epoch Train Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc:.2f}%')

print("Train finished")
#print_gpu_utilization()

  return fn(*args, **kwargs)
Training: 100%|██████████| 313/313 [00:35<00:00,  8.74it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.42it/s]


Epoch: 01 | Epoch Train Time: 0m 35s
	Train Loss: 2.382 | Train Acc: 9.86%
	 Val. Loss: 2.373 |  Val. Acc: 10.27%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.42it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.25it/s]


Epoch: 02 | Epoch Train Time: 0m 33s
	Train Loss: 2.348 | Train Acc: 10.30%
	 Val. Loss: 2.336 |  Val. Acc: 11.03%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.40it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.90it/s]


Epoch: 03 | Epoch Train Time: 0m 33s
	Train Loss: 2.325 | Train Acc: 10.47%
	 Val. Loss: 2.313 |  Val. Acc: 10.97%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.36it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.49it/s]


Epoch: 04 | Epoch Train Time: 0m 33s
	Train Loss: 2.305 | Train Acc: 10.88%
	 Val. Loss: 2.299 |  Val. Acc: 11.17%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.40it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.41it/s]


Epoch: 05 | Epoch Train Time: 0m 33s
	Train Loss: 2.290 | Train Acc: 11.73%
	 Val. Loss: 2.278 |  Val. Acc: 12.63%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.37it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.00it/s]


Epoch: 06 | Epoch Train Time: 0m 33s
	Train Loss: 2.277 | Train Acc: 12.55%
	 Val. Loss: 2.271 |  Val. Acc: 12.90%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.45it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.15it/s]


Epoch: 07 | Epoch Train Time: 0m 33s
	Train Loss: 2.262 | Train Acc: 13.50%
	 Val. Loss: 2.258 |  Val. Acc: 13.57%


Training: 100%|██████████| 313/313 [00:32<00:00,  9.54it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.64it/s]


Epoch: 08 | Epoch Train Time: 0m 32s
	Train Loss: 2.252 | Train Acc: 13.69%
	 Val. Loss: 2.250 |  Val. Acc: 14.10%


Training: 100%|██████████| 313/313 [00:32<00:00,  9.52it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.69it/s]


Epoch: 09 | Epoch Train Time: 0m 32s
	Train Loss: 2.244 | Train Acc: 14.40%
	 Val. Loss: 2.246 |  Val. Acc: 15.30%


Training: 100%|██████████| 313/313 [00:32<00:00,  9.54it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.63it/s]


Epoch: 10 | Epoch Train Time: 0m 32s
	Train Loss: 2.236 | Train Acc: 15.01%
	 Val. Loss: 2.232 |  Val. Acc: 15.60%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.46it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.14it/s]


Epoch: 11 | Epoch Train Time: 0m 33s
	Train Loss: 2.229 | Train Acc: 15.69%
	 Val. Loss: 2.232 |  Val. Acc: 15.93%


Training: 100%|██████████| 313/313 [00:32<00:00,  9.51it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.29it/s]


Epoch: 12 | Epoch Train Time: 0m 32s
	Train Loss: 2.221 | Train Acc: 15.89%
	 Val. Loss: 2.225 |  Val. Acc: 15.70%


Training: 100%|██████████| 313/313 [00:32<00:00,  9.50it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.47it/s]


Epoch: 13 | Epoch Train Time: 0m 32s
	Train Loss: 2.215 | Train Acc: 16.78%
	 Val. Loss: 2.218 |  Val. Acc: 16.63%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.43it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.60it/s]


Epoch: 14 | Epoch Train Time: 0m 33s
	Train Loss: 2.212 | Train Acc: 16.93%
	 Val. Loss: 2.219 |  Val. Acc: 16.27%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.42it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.79it/s]


Epoch: 15 | Epoch Train Time: 0m 33s
	Train Loss: 2.206 | Train Acc: 17.32%
	 Val. Loss: 2.204 |  Val. Acc: 17.90%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.40it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.06it/s]


Epoch: 16 | Epoch Train Time: 0m 33s
	Train Loss: 2.196 | Train Acc: 18.22%
	 Val. Loss: 2.203 |  Val. Acc: 17.47%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.42it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.36it/s]


Epoch: 17 | Epoch Train Time: 0m 33s
	Train Loss: 2.192 | Train Acc: 18.71%
	 Val. Loss: 2.195 |  Val. Acc: 18.53%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.38it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.73it/s]


Epoch: 18 | Epoch Train Time: 0m 33s
	Train Loss: 2.187 | Train Acc: 18.59%
	 Val. Loss: 2.189 |  Val. Acc: 19.30%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.41it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.37it/s]


Epoch: 19 | Epoch Train Time: 0m 33s
	Train Loss: 2.183 | Train Acc: 19.27%
	 Val. Loss: 2.185 |  Val. Acc: 20.03%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.39it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.35it/s]


Epoch: 20 | Epoch Train Time: 0m 33s
	Train Loss: 2.180 | Train Acc: 18.93%
	 Val. Loss: 2.183 |  Val. Acc: 19.40%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.39it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.48it/s]


Epoch: 21 | Epoch Train Time: 0m 33s
	Train Loss: 2.170 | Train Acc: 19.92%
	 Val. Loss: 2.181 |  Val. Acc: 19.73%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.34it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.36it/s]


Epoch: 22 | Epoch Train Time: 0m 33s
	Train Loss: 2.166 | Train Acc: 20.67%
	 Val. Loss: 2.170 |  Val. Acc: 20.07%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.45it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.93it/s]


Epoch: 23 | Epoch Train Time: 0m 33s
	Train Loss: 2.164 | Train Acc: 20.45%
	 Val. Loss: 2.166 |  Val. Acc: 20.43%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.39it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.77it/s]


Epoch: 24 | Epoch Train Time: 0m 33s
	Train Loss: 2.154 | Train Acc: 20.59%
	 Val. Loss: 2.163 |  Val. Acc: 20.73%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.40it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.74it/s]


Epoch: 25 | Epoch Train Time: 0m 33s
	Train Loss: 2.154 | Train Acc: 21.08%
	 Val. Loss: 2.156 |  Val. Acc: 21.60%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.32it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.93it/s]


Epoch: 26 | Epoch Train Time: 0m 33s
	Train Loss: 2.151 | Train Acc: 21.15%
	 Val. Loss: 2.152 |  Val. Acc: 21.87%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.44it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.29it/s]


Epoch: 27 | Epoch Train Time: 0m 33s
	Train Loss: 2.144 | Train Acc: 21.52%
	 Val. Loss: 2.151 |  Val. Acc: 21.00%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.43it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.58it/s]


Epoch: 28 | Epoch Train Time: 0m 33s
	Train Loss: 2.141 | Train Acc: 21.56%
	 Val. Loss: 2.137 |  Val. Acc: 21.40%


Training: 100%|██████████| 313/313 [00:32<00:00,  9.49it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.42it/s]


Epoch: 29 | Epoch Train Time: 0m 32s
	Train Loss: 2.136 | Train Acc: 21.50%
	 Val. Loss: 2.138 |  Val. Acc: 21.63%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.48it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.92it/s]


Epoch: 30 | Epoch Train Time: 0m 33s
	Train Loss: 2.135 | Train Acc: 21.97%
	 Val. Loss: 2.131 |  Val. Acc: 22.07%


Training: 100%|██████████| 313/313 [00:32<00:00,  9.54it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.56it/s]


Epoch: 31 | Epoch Train Time: 0m 32s
	Train Loss: 2.126 | Train Acc: 22.20%
	 Val. Loss: 2.130 |  Val. Acc: 21.77%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.46it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.71it/s]


Epoch: 32 | Epoch Train Time: 0m 33s
	Train Loss: 2.122 | Train Acc: 22.13%
	 Val. Loss: 2.122 |  Val. Acc: 22.03%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.41it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.74it/s]


Epoch: 33 | Epoch Train Time: 0m 33s
	Train Loss: 2.121 | Train Acc: 21.77%
	 Val. Loss: 2.122 |  Val. Acc: 22.23%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.45it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.29it/s]


Epoch: 34 | Epoch Train Time: 0m 33s
	Train Loss: 2.117 | Train Acc: 22.28%
	 Val. Loss: 2.110 |  Val. Acc: 22.17%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.43it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.55it/s]


Epoch: 35 | Epoch Train Time: 0m 33s
	Train Loss: 2.110 | Train Acc: 22.47%
	 Val. Loss: 2.120 |  Val. Acc: 22.33%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.47it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.47it/s]


Epoch: 36 | Epoch Train Time: 0m 33s
	Train Loss: 2.108 | Train Acc: 22.27%
	 Val. Loss: 2.113 |  Val. Acc: 22.33%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.44it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.83it/s]


Epoch: 37 | Epoch Train Time: 0m 33s
	Train Loss: 2.107 | Train Acc: 22.85%
	 Val. Loss: 2.109 |  Val. Acc: 23.03%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.45it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.39it/s]


Epoch: 38 | Epoch Train Time: 0m 33s
	Train Loss: 2.104 | Train Acc: 22.32%
	 Val. Loss: 2.102 |  Val. Acc: 22.97%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.43it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.86it/s]


Epoch: 39 | Epoch Train Time: 0m 33s
	Train Loss: 2.098 | Train Acc: 23.18%
	 Val. Loss: 2.098 |  Val. Acc: 23.13%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.47it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.45it/s]


Epoch: 40 | Epoch Train Time: 0m 33s
	Train Loss: 2.094 | Train Acc: 22.92%
	 Val. Loss: 2.088 |  Val. Acc: 23.47%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.39it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.39it/s]


Epoch: 41 | Epoch Train Time: 0m 33s
	Train Loss: 2.089 | Train Acc: 23.21%
	 Val. Loss: 2.085 |  Val. Acc: 23.23%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.39it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.32it/s]


Epoch: 42 | Epoch Train Time: 0m 33s
	Train Loss: 2.087 | Train Acc: 23.15%
	 Val. Loss: 2.079 |  Val. Acc: 23.77%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.41it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.83it/s]


Epoch: 43 | Epoch Train Time: 0m 33s
	Train Loss: 2.084 | Train Acc: 23.61%
	 Val. Loss: 2.087 |  Val. Acc: 24.10%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.40it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.65it/s]


Epoch: 44 | Epoch Train Time: 0m 33s
	Train Loss: 2.080 | Train Acc: 23.40%
	 Val. Loss: 2.075 |  Val. Acc: 24.20%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.35it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.52it/s]


Epoch: 45 | Epoch Train Time: 0m 33s
	Train Loss: 2.072 | Train Acc: 23.89%
	 Val. Loss: 2.070 |  Val. Acc: 24.40%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.34it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.29it/s]


Epoch: 46 | Epoch Train Time: 0m 33s
	Train Loss: 2.071 | Train Acc: 23.67%
	 Val. Loss: 2.066 |  Val. Acc: 24.47%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.42it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.81it/s]


Epoch: 47 | Epoch Train Time: 0m 33s
	Train Loss: 2.068 | Train Acc: 24.25%
	 Val. Loss: 2.066 |  Val. Acc: 24.77%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.35it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.58it/s]


Epoch: 48 | Epoch Train Time: 0m 33s
	Train Loss: 2.063 | Train Acc: 23.99%
	 Val. Loss: 2.067 |  Val. Acc: 24.67%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.39it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 20.48it/s]


Epoch: 49 | Epoch Train Time: 0m 33s
	Train Loss: 2.062 | Train Acc: 24.64%
	 Val. Loss: 2.054 |  Val. Acc: 24.53%


Training: 100%|██████████| 313/313 [00:33<00:00,  9.40it/s]
Validation: 100%|██████████| 94/94 [00:04<00:00, 19.72it/s]

Epoch: 50 | Epoch Train Time: 0m 33s
	Train Loss: 2.060 | Train Acc: 24.25%
	 Val. Loss: 2.052 |  Val. Acc: 24.23%
Train finished





In [31]:
print_gpu_utilization()

Allocated Memory: 0.10 GB
Reserved Memory: 0.44 GB


In [32]:
free_memory, total_memory = torch.cuda.mem_get_info()
print(f"Free memory: {free_memory / 1024**2:.2f} MB")
print(f"Total memory: {total_memory / 1024**2:.2f} MB")

Free memory: 39544.81 MB
Total memory: 40513.81 MB


In [33]:
print("ResNet18 with Melon")
print(f'Total Training Time: {int(total_time/60)}m {int(total_time%60)}s')

ResNet18 with Melon
Total Training Time: 27m 25s


In [34]:
from torch import profiler

dummy_input = torch.randn(32, 3, 224, 224).cuda()

# Profiling inference
with profiler.profile(
    activities=[
       profiler.ProfilerActivity.CPU,
        profiler.ProfilerActivity.CUDA,  # Include if using GPU
    ],
    on_trace_ready=profiler.tensorboard_trace_handler("./logs"),  # Optional logging
    record_shapes=True,
    with_stack=True
) as prof:
    with torch.no_grad():
        model(dummy_input)


# Print results
print(prof.key_averages().table(sort_by="cuda_time_total" if torch.cuda.is_available() else "cpu_time_total", row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::conv2d         1.89%     313.629us        53.52%       8.889ms     444.458us       0.000us         0.00%       2.695ms     134.774us            20  
                                      aten::convolution         1.19%     197.234us        51.63%       8.576ms     428.777us       0.000us         0.00%       2.695ms     134.774us            20  
         