In [None]:
!pip install pynvml

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler

import matplotlib.pyplot as plt
import numpy as np

import copy
from collections import namedtuple
import time
import os
import random

import cv2
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image

from tqdm import tqdm
from pynvml import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def print_gpu_utilization():
    """Print the CUDA memory currently allocated and reserved by PyTorch.

    Values are reported in GiB (bytes / 1024**3) for the current CUDA device.
    When CUDA is unavailable a single notice line is printed instead.
    """
    if not torch.cuda.is_available():
        print("No GPU available.")
        return

    bytes_per_gib = 1024**3
    current = torch.cuda.current_device()  # index of the active CUDA device
    allocated_memory = torch.cuda.memory_allocated(current) / bytes_per_gib  # memory in use (GB)
    reserved_memory = torch.cuda.memory_reserved(current) / bytes_per_gib  # reserved memory (GB)
    print(f"Allocated Memory: {allocated_memory:.2f} GB")
    print(f"Reserved Memory: {reserved_memory:.2f} GB")

In [None]:
def print_summary(result):
    """Print runtime, throughput and GPU memory usage for a finished run.

    Args:
        result: Object exposing a ``metrics`` mapping with the keys
            ``'train_runtime'`` and ``'train_samples_per_second'``.
            NOTE(review): this looks like a Hugging Face ``TrainOutput`` —
            confirm against the caller.
    """
    runtime = result.metrics['train_runtime']
    print(f"Time: {runtime:.2f}")
    throughput = result.metrics['train_samples_per_second']
    print(f"Samples/second: {throughput:.2f}")
    print_gpu_utilization()

In [None]:
# Shared preprocessing / loading configuration used by the cells below.
size = 224  # side length (pixels) of the crops produced by both transform pipelines
# Per-channel statistics handed to transforms.Normalize.
# NOTE(review): these look like the usual ImageNet statistics rather than
# CIFAR-10 ones — confirm that is intended.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 32  # batch size passed to the DataLoaders below

In [None]:
# Training pipeline: random augmentation (random resized crop covering 50-100%
# of the image area, random horizontal flip), then tensor conversion and
# per-channel normalisation with `mean` / `std`.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size, scale=(0.5, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Evaluation pipeline: deterministic resize to 256 followed by a centre crop to
# `size`, then the same tensor conversion and normalisation as training.
test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

In [None]:
# CIFAR-10
# Both splits are stored under ./data and downloaded on first use.
# The training split gets the augmenting pipeline, the test split the
# deterministic one.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transforms)
# NOTE(review): `trainloader` / `testloader` iterate over the full splits; the
# training cells further down use the *_iterator loaders instead — confirm
# whether these two are still needed.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transforms)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [None]:
# Despite its name, VALID_RATIO is the fraction of `trainset` kept for
# training (0.7 -> 70% train); the remaining examples form the validation split.
VALID_RATIO = 0.7
n_train_examples = int(len(trainset) * VALID_RATIO)
n_valid_examples = len(trainset) - n_train_examples

# random_split is called without a generator/seed here, so the partition
# differs between runs unless the global torch RNG is seeded elsewhere.
train_data, valid_data = data.random_split(trainset, [n_train_examples, n_valid_examples])

In [None]:
# Deep-copy the validation split so that replacing the transform on its
# underlying dataset does not also change the transform seen through
# `train_data` / `trainset` (which keep the augmenting training pipeline).
valid_data = copy.deepcopy(valid_data)
valid_data.dataset.transform = test_transforms

In [None]:
len(train_data), len(valid_data), len(testset)

In [None]:
# Fraction of each split that is actually used, to keep experiments short.
sample_fraction = 0.2

# Draw random indices for every split.
# The training indices are drawn from `train_data` (the training part of the
# random split), NOT from the full `trainset`: sampling from `trainset` would
# put examples that belong to the validation split into the training subset
# and leak validation data into training.
# `.tolist()` hands plain Python ints to Subset, which then index the
# underlying datasets without relying on 0-d tensor -> int conversion.
train_indices = torch.randperm(len(train_data))[:int(len(train_data) * sample_fraction)].tolist()
valid_indices = torch.randperm(len(valid_data))[:int(len(valid_data) * sample_fraction)].tolist()
test_indices = torch.randperm(len(testset))[:int(len(testset) * sample_fraction)].tolist()

# Build the subsets.
train_subset = Subset(train_data, train_indices)
valid_subset = Subset(valid_data, valid_indices)
test_subset = Subset(testset, test_indices)

In [None]:
len(train_subset), len(valid_subset), len(test_subset)

In [None]:
# DataLoaders over the sampled subsets. Only the training loader shuffles;
# validation and test batches are served in a fixed order.
train_iterator = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
valid_iterator = DataLoader(valid_subset, batch_size=batch_size, shuffle=False)
test_iterator = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN).

    The block input is added to the second BatchNorm output before the final
    ReLU. When ``downsample`` is truthy the shortcut goes through a strided
    1x1 convolution followed by BatchNorm; otherwise the input is added as-is.

    Attributes:
        expansion: Ratio between the block's output channels and the
            ``out_channels`` argument (1 for this block).
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample = False):
        """Build the block.

        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels produced by both 3x3 convolutions.
            stride: Stride of the first convolution and of the shortcut
                projection.
            downsample: Whether to project the shortcut with 1x1 conv + BN.
        """
        super().__init__()
        # Main path. Submodules are created in a fixed order so parameter
        # initialisation order and state-dict keys stay stable.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut path: either a projection or nothing (plain identity).
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if downsample
            else None
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

In [None]:
class ResNet(nn.Module):
    """ResNet-style classifier assembled from a ``(block, n_blocks, channels)`` config.

    The network is a 7x7 strided stem (conv -> BN -> ReLU -> max-pool), four
    residual stages, global average pooling and a linear classifier.
    ``forward`` returns both the classifier output and the pooled features.
    """

    def __init__(self, config, output_dim, zero_init_residual = False):
        """Build the network.

        Args:
            config: Triple ``(block, n_blocks, channels)``; ``n_blocks`` and
                ``channels`` must both have length 4.
            output_dim: Size of the classifier output.
            zero_init_residual: When true, the weight of the last BatchNorm of
                every ``BasicBlock`` is initialised to zero.
        """
        super().__init__()

        block, n_blocks, channels = config
        # Running channel count; updated by get_resnet_layer after each stage.
        self.in_channels = channels[0]
        assert len(n_blocks) == len(channels) == 4

        # Stem
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages layer1..layer4: the first keeps the resolution,
        # the remaining three use stride 2.
        stage_strides = (1, 2, 2, 2)
        for position, (depth, width, stride) in enumerate(zip(n_blocks, channels, stage_strides), start=1):
            setattr(self, f"layer{position}", self.get_resnet_layer(block, depth, width, stride=stride))

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.in_channels, output_dim)

        if zero_init_residual:
            # Only BasicBlock is handled; a Bottleneck variant (bn3) is not
            # defined in this notebook.
            for module in self.modules():
                if isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        """Create one residual stage and advance ``self.in_channels``.

        The first block receives ``stride`` and gets a projection shortcut
        whenever the incoming channel count differs from the stage's output
        channel count; the remaining ``n_blocks - 1`` blocks use the block's
        defaults.
        """
        stage_channels = block.expansion * channels
        needs_projection = self.in_channels != stage_channels

        blocks = [block(self.in_channels, channels, stride, needs_projection)]
        blocks.extend(block(stage_channels, channels) for _ in range(1, n_blocks))

        self.in_channels = stage_channels
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``(logits, features)`` where features are the flattened pooled activations."""
        feats = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = self.avgpool(feats)
        h = pooled.view(pooled.shape[0], -1)
        return self.fc(h), h

In [None]:
ResNetConfig = namedtuple('ResNetConfig', ['block', 'n_blocks', 'channels'])

In [None]:
resnet18_config = ResNetConfig(block = BasicBlock, n_blocks = [2, 2, 2, 2], channels = [64, 128, 256, 512])

In [None]:
pretrained_model = models.resnet18(pretrained=True)

In [None]:
print(pretrained_model)

In [None]:
model = ResNet(resnet18_config, 10)

In [None]:
print(model)

In [None]:
# Move the freshly initialised ResNet to the target device first and keep it
# bound to `model`. The result is deliberately NOT assigned to
# `pretrained_model`: that name holds the torchvision network loaded above,
# and rebinding it to the randomly initialised model both discarded the
# pretrained weights and mislabelled the model being trained.
model = model.to(device)

# The optimizer is created after the device move so that it is built from the
# parameters in their final location (the order recommended by torch.optim).
optimizer = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [None]:
def calculate_accuracy(y_pred, y):
    """Return the fraction of rows of ``y_pred`` whose arg-max equals ``y``.

    Args:
        y_pred: Score tensor; the arg-max is taken along dimension 1.
        y: Target class indices, one per row of ``y_pred``.

    Returns:
        0-d float tensor: number of correct predictions divided by ``y.shape[0]``.
    """
    predicted = y_pred.argmax(1, keepdim=True)
    hits = predicted.eq(y.view_as(predicted)).sum()
    return hits.float() / y.shape[0]

In [None]:
# Version with debugging code included (LifetimeAwareMemoryPool & MelonTrainer)
# class LifetimeAwareMemoryPool:
#     def __init__(self, memory_budget):
#         print(f"[DEBUG] Initializing MemoryPool with budget: {memory_budget}")
#         self.memory_budget = memory_budget
#         self.allocated_memory = 0
#         self.memory_blocks = []
#         self.tensor_map = {}

#     def allocate(self, tensor_id, size, lifetime):
#         print(f"[DEBUG] Attempting to allocate tensor {tensor_id} with size {size} and lifetime {lifetime}")

#         if tensor_id in self.tensor_map:
#             print(f"[DEBUG] Tensor {tensor_id} already allocated")
#             return self.tensor_map[tensor_id]

#         best_addr = self._find_best_fit(size, lifetime)

#         if best_addr is None:
#             print(f"[DEBUG] Memory fragmented, performing compaction")
#             self._compact()
#             best_addr = self._find_best_fit(size, lifetime)
#             if best_addr is None:
#                 print(f"[DEBUG] Failed to allocate memory for tensor {tensor_id}")
#                 raise MemoryError("Not enough memory")

#         block_index = len(self.memory_blocks)
#         self.memory_blocks.append((best_addr, size, tensor_id, lifetime))
#         self.tensor_map[tensor_id] = block_index
#         self.allocated_memory += size

#         print(f"[DEBUG] Successfully allocated tensor {tensor_id} at address {best_addr}")
#         return best_addr

#     def free(self, tensor_id):
#         print(f"[DEBUG] Attempting to free tensor {tensor_id}")
#         if tensor_id in self.tensor_map:
#             block_index = self.tensor_map[tensor_id]
#             _, size, _, _ = self.memory_blocks[block_index]
#             self.allocated_memory -= size
#             del self.tensor_map[tensor_id]
#             self.memory_blocks[block_index] = None
#             print(f"[DEBUG] Successfully freed tensor {tensor_id}")
#         else:
#             print(f"[DEBUG] Tensor {tensor_id} not found in memory pool")

#     def _find_best_fit(self, size, lifetime):
#         print(f"[DEBUG] Finding best fit for size {size} with lifetime {lifetime}")

#         if self.allocated_memory + size > self.memory_budget:
#             print(f"[DEBUG] Not enough memory in budget")
#             return None

#         # 1. 재사용 가능한 메모리 블록 찾기
#         available_addr = 0
#         for block in self.memory_blocks:
#             if block is None:
#                 continue
#             block_addr, block_size, block_id, block_lifetime = block
#             print(f"[DEBUG] Checking block at {block_addr} with size {block_size} (tensor {block_id})")

#             # 수명이 겹치지 않는 경우 해당 공간 재사용
#             if not self._lifetimes_overlap(lifetime, block_lifetime):
#                 print(f"[DEBUG] Found potential reuse block at {block_addr}")
#                 if available_addr == 0:  # 첫 번째로 찾은 재사용 가능한 블록 사용
#                     print(f"[DEBUG] Reusing memory at address {available_addr}")
#                     return available_addr
#             available_addr = max(available_addr, block_addr + block_size)

#         # 2. 새로운 메모리 공간 할당
#         if self.allocated_memory + size <= self.memory_budget:
#             print(f"[DEBUG] Allocating at new address {available_addr}")
#             return available_addr

#         print(f"[DEBUG] No suitable location found")
#         return None

#     def _calculate_address_at_position(self, pos):
#         """주어진 위치에 맞는 메모리 주소 계산"""
#         if pos == 0:
#             return 0
#         prev_block = self.memory_blocks[pos-1]
#         return prev_block[0] + prev_block[1]

#     def _compact(self):
#         print(f"[DEBUG] Starting memory compaction")
#         valid_blocks = [b for b in self.memory_blocks if b is not None]
#         print(f"[DEBUG] Found {len(valid_blocks)} valid blocks")

#         valid_blocks.sort(key=lambda x: x[3])

#         self.memory_blocks = []
#         self.tensor_map.clear()
#         self.allocated_memory = 0

#         current_addr = 0
#         for _, size, tensor_id, lifetime in valid_blocks:
#             self.memory_blocks.append((current_addr, size, tensor_id, lifetime))
#             self.tensor_map[tensor_id] = len(self.memory_blocks) - 1
#             self.allocated_memory += size
#             current_addr += size
#             print(f"[DEBUG] Reallocated tensor {tensor_id} to address {current_addr-size}")

#     def _lifetimes_overlap(self, lifetime1, lifetime2):
#         print(f"[DEBUG] Checking lifetime overlap: {lifetime1} vs {lifetime2}")
#         start1, end1 = lifetime1
#         start2, end2 = lifetime2
#         overlap = not (end1 <= start2 or end2 <= start1)
#         print(f"[DEBUG] Overlap result: {overlap}")
#         return overlap


# class MelonTrainer:
#     def __init__(self, model, criterion, optimizer, device, memory_budget):
#         print(f"[DEBUG] Initializing MelonTrainer with memory budget: {memory_budget}")
#         self.model = model.to(device)
#         self.criterion = criterion
#         self.optimizer = optimizer
#         self.device = device
#         self.memory_budget = memory_budget
#         self.has_bn = self._check_has_bn()
#         print(f"[DEBUG] Model has BatchNorm layers: {self.has_bn}")
#         self.memory_pool = self._initialize_memory_pool()

#     def _check_has_bn(self):
#         print("[DEBUG] Checking for BatchNorm layers in model")
#         for module in self.model.modules():
#             if isinstance(module, nn.BatchNorm2d):
#                 print("[DEBUG] Found BatchNorm2d layer")
#                 return True
#         print("[DEBUG] No BatchNorm layers found")
#         return False

#     def _initialize_memory_pool(self):
#         print("[DEBUG] Initializing LifetimeAwareMemoryPool")
#         return LifetimeAwareMemoryPool(self.memory_budget)

#     def train(self, train_loader):
#         print("[DEBUG] Starting training")
#         start_time = time.monotonic()
#         self.model.train()
#         running_loss = 0.0
#         correct = 0
#         total = 0

#         print("[DEBUG] Setting up profiler")
#         with profile(
#             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
#             profile_memory=True,
#             record_shapes=True
#         ) as prof:
#             for batch_idx, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training")):
#                 print(f"[DEBUG] Processing batch {batch_idx}")
#                 inputs, labels = inputs.to(self.device), labels.to(self.device)

#                 # Allocate memory for inputs and labels using the memory pool
#                 input_tensor_id = f"input_batch_{batch_idx}"
#                 label_tensor_id = f"label_batch_{batch_idx}"

#                 try:
#                     # Allocate memory for the input and label tensors
#                     self.memory_pool.allocate(input_tensor_id, inputs.element_size() * inputs.nelement(), (batch_idx, batch_idx + 1))
#                     self.memory_pool.allocate(label_tensor_id, labels.element_size() * labels.nelement(), (batch_idx, batch_idx + 1))
#                     print(f"[DEBUG] Allocated memory for input and label tensors for batch {batch_idx}")
#                 except MemoryError as e:
#                     print(f"[ERROR] Memory allocation failed for batch {batch_idx}: {e}")
#                     continue

#                 if self.has_bn:
#                     # BatchNorm이 있는 경우 recomputation 사용
#                     print("[DEBUG] Using recomputation strategy")
#                     loss, acc = self._train_step_with_recomputation(inputs, labels)
#                 else:
#                     # BatchNorm이 없는 경우 micro-batch 사용
#                     print("[DEBUG] Using micro-batch strategy")
#                     loss, acc = self._train_step_with_microbatch(inputs, labels)

#                 running_loss += loss
#                 correct += acc[0]
#                 total += acc[1]

#                 # Free the allocated memory after processing the batch
#                 self.memory_pool.free(input_tensor_id)
#                 self.memory_pool.free(label_tensor_id)
#                 print(f"[DEBUG] Freed memory for input and label tensors for batch {batch_idx}")

#                 print(f"[DEBUG] Batch {batch_idx} - Loss: {loss:.4f}, Accuracy: {acc[0]/acc[1]*100:.2f}%")

#         end_time = time.monotonic()
#         epoch_loss = running_loss / len(train_loader)
#         accuracy = 100 * correct / total if total > 0 else 0

#         print(f"[DEBUG] Training completed - Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")
#         return epoch_loss, accuracy, start_time, end_time

#     def _train_step_with_recomputation(self, inputs, labels):
#         print("[DEBUG] Starting recomputation training step")
#         self.optimizer.zero_grad()

#         # Forward pass with checkpoints
#         with torch.no_grad():
#             intermediate_outputs = []
#             x = inputs

#             print("[DEBUG] Processing initial layers")
#             x = self.model.conv1(x)
#             x = self.model.bn1(x)
#             x = self.model.relu(x)
#             x = self.model.maxpool(x)

#             print("[DEBUG] Processing main layers")
#             for layer_name in ['layer1', 'layer2', 'layer3', 'layer4']:
#                 print(f"[DEBUG] Processing {layer_name}")
#                 layer = getattr(self.model, layer_name)
#                 x = layer(x)
#                 if self.has_bn:
#                     intermediate_outputs.append(x.detach())
#                     print(f"[DEBUG] Saved checkpoint for {layer_name}")

#         print("[DEBUG] Starting forward pass")
#         with record_function("forward_pass"):
#             outputs = self.model(inputs)
#             if isinstance(outputs, tuple):
#                 outputs = outputs[0]

#         print("[DEBUG] Computing loss")
#         with record_function("loss_computation"):
#             loss = self.criterion(outputs, labels)

#         print("[DEBUG] Backward pass")
#         with record_function("backward_pass"):
#             loss.backward()

#         print("[DEBUG] Optimizer step")
#         with record_function("optimizer_step"):
#             self.optimizer.step()

#         _, predicted = torch.max(outputs.data, 1)
#         correct = (predicted == labels).sum().item()
#         total = labels.size(0)

#         print(f"[DEBUG] Step completed - Loss: {loss.item():.4f}, Accuracy: {correct/total*100:.2f}%")
#         return loss.item(), (correct, total)

#     def _calculate_micro_batch_size(self, input_size):
#         print(f"[DEBUG] Calculating micro-batch size for input shape {input_size}")
#         tensor_size = input_size[1] * input_size[2] * input_size[3] * 4
#         micro_batch_size = max(1, min(input_size[0], self.memory_budget // tensor_size))
#         print(f"[DEBUG] Calculated micro-batch size: {micro_batch_size}")
#         return micro_batch_size

In [None]:
class LifetimeAwareMemoryPool:
    """Bookkeeping-only planner that assigns offsets inside a fixed budget.

    The pool never touches real memory: it only records, for every registered
    tensor id, a block ``(address, size, tensor_id, lifetime)`` where
    ``lifetime`` is a ``(start, end)`` pair. Two blocks may share addresses
    only when their lifetimes do not overlap.

    Attributes:
        memory_budget: Upper bound for ``address + size`` of every block and
            for the summed size of all live blocks.
        allocated_memory: Summed size of all live blocks.
        memory_blocks: List of block tuples; freed slots hold ``None`` until
            they are trimmed or compacted away.
        tensor_map: Maps a tensor id to its index in ``memory_blocks``.
    """

    def __init__(self, memory_budget):
        self.memory_budget = memory_budget
        self.allocated_memory = 0
        self.memory_blocks = []
        self.tensor_map = {}

    def allocate(self, tensor_id, size, lifetime):
        """Reserve ``size`` units for ``tensor_id`` and return the block address.

        Asking again for an id that is already registered returns the address
        of the existing block (previously the internal list index leaked out
        here, so the return value changed meaning between calls).

        Raises:
            MemoryError: If the summed size of live blocks plus ``size``
                exceeds the budget.
        """
        if tensor_id in self.tensor_map:
            return self.memory_blocks[self.tensor_map[tensor_id]][0]

        best_addr = self._find_best_fit(size, lifetime)

        if best_addr is None:
            # Either the budget is exhausted or the live blocks are spread out
            # so that no in-budget gap exists; repacking fixes the latter.
            self._compact()
            best_addr = self._find_best_fit(size, lifetime)
            if best_addr is None:
                raise MemoryError("Not enough memory")

        self._register(best_addr, size, tensor_id, lifetime)
        return best_addr

    def free(self, tensor_id):
        """Release the block of ``tensor_id``; unknown ids are ignored."""
        block_index = self.tensor_map.pop(tensor_id, None)
        if block_index is None:
            return

        self.allocated_memory -= self.memory_blocks[block_index][1]
        self.memory_blocks[block_index] = None

        # Drop tombstones at the tail so the list does not grow without bound
        # over a long run. Only trailing entries are removed, therefore the
        # indices stored in tensor_map for the remaining blocks stay valid.
        while self.memory_blocks and self.memory_blocks[-1] is None:
            self.memory_blocks.pop()

    def _find_best_fit(self, size, lifetime):
        """Return the lowest usable address for the request, or ``None``.

        ``None`` is returned when the summed size of live blocks plus ``size``
        exceeds the budget, or when the lowest conflict-free address would
        place the block beyond the budget.
        """
        if self.allocated_memory + size > self.memory_budget:
            return None

        candidate = self._lowest_free_address(size, lifetime)
        if candidate + size > self.memory_budget:
            return None
        return candidate

    def _lowest_free_address(self, size, lifetime):
        """First-fit search among blocks whose lifetime overlaps ``lifetime``.

        Blocks with a disjoint lifetime are ignored, which is what allows
        address reuse. The result is not checked against the budget.
        """
        conflicts = sorted(
            (block[0], block[0] + block[1])
            for block in self.memory_blocks
            if block is not None and self._lifetimes_overlap(lifetime, block[3])
        )

        candidate = 0
        for start, end in conflicts:
            if candidate + size <= start:
                break  # the gap in front of this block is large enough
            candidate = max(candidate, end)
        return candidate

    def _register(self, address, size, tensor_id, lifetime):
        """Append a live block and index it in ``tensor_map``."""
        self.tensor_map[tensor_id] = len(self.memory_blocks)
        self.memory_blocks.append((address, size, tensor_id, lifetime))
        self.allocated_memory += size

    def _calculate_address_at_position(self, pos):
        """Return the end address of the nearest live block before index ``pos``.

        Freed (``None``) slots are skipped; 0 is returned when no live block
        precedes ``pos``.
        """
        if pos <= 0:
            return 0
        for block in reversed(self.memory_blocks[:pos]):
            if block is not None:
                return block[0] + block[1]
        return 0

    def _compact(self):
        """Repack all live blocks towards address 0.

        Blocks are re-placed in order of their lifetime using the same
        lifetime-aware first-fit rule as new allocations, so blocks with
        disjoint lifetimes keep sharing addresses. Addresses handed out
        earlier may change.
        """
        live_blocks = sorted(
            (block for block in self.memory_blocks if block is not None),
            key=lambda block: block[3],
        )

        self.memory_blocks = []
        self.tensor_map.clear()
        self.allocated_memory = 0

        for _, size, tensor_id, lifetime in live_blocks:
            address = self._lowest_free_address(size, lifetime)
            self._register(address, size, tensor_id, lifetime)

    def _lifetimes_overlap(self, lifetime1, lifetime2):
        """Return True when the two ``(start, end)`` ranges intersect.

        Ranges that merely touch (one ends where the other starts) do not
        count as overlapping.
        """
        start1, end1 = lifetime1
        start2, end2 = lifetime2
        return not (end1 <= start2 or end2 <= start1)

In [None]:
class MelonTrainer:
    """Single-epoch trainer that books batch tensors in a LifetimeAwareMemoryPool.

    For every batch the byte sizes of the input and label tensors are
    registered in the pool for the duration of the step. Models containing
    BatchNorm layers are trained with one full-batch step; other models are
    trained with gradient accumulation over micro-batches sized from the
    memory budget.

    Attributes:
        model, criterion, optimizer, device: Training components as passed in
            (the model is moved to ``device``).
        memory_budget: Budget handed to the pool and used to size micro-batches.
        has_bn: Whether the model contains a BatchNorm layer.
        memory_pool: The pool instance used for bookkeeping.
        last_profile: Profiler object of the most recent ``train`` call
            (``None`` before the first call).
    """

    def __init__(self, model, criterion, optimizer, device, memory_budget):
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.memory_budget = memory_budget
        self.has_bn = self._check_has_bn()
        self.memory_pool = self._initialize_memory_pool()
        self.last_profile = None

    def _check_has_bn(self):
        """Return True when the model contains a BatchNorm1d/2d/3d layer.

        Splitting a batch changes the statistics such layers compute, so their
        presence selects the full-batch training step.
        """
        batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        return any(isinstance(module, batch_norm_types) for module in self.model.modules())

    def _initialize_memory_pool(self):
        """Create the pool that tracks per-batch tensor bookkeeping."""
        return LifetimeAwareMemoryPool(self.memory_budget)

    def train(self, train_loader):
        """Run one training epoch under the torch profiler.

        Batches whose tensors do not fit into the memory pool are skipped with
        a warning; they do not contribute to the reported loss or accuracy.

        Returns:
            Tuple ``(epoch_loss, accuracy, start_time, end_time)``: mean loss
            over the processed batches, accuracy in percent, and the
            ``time.monotonic()`` readings taken at the start and end.
        """
        start_time = time.monotonic()
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        processed_batches = 0

        # Only request CUDA activity when a CUDA device actually exists.
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)

        with profile(
            activities=activities,
            profile_memory=True,
            record_shapes=True
        ) as prof:
            for batch_idx, (inputs, labels) in enumerate(tqdm(train_loader, desc="Training")):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Book the input and label tensors in the memory pool.
                input_tensor_id = f"input_batch_{batch_idx}"
                label_tensor_id = f"label_batch_{batch_idx}"

                try:
                    self.memory_pool.allocate(input_tensor_id, inputs.element_size() * inputs.nelement(), (batch_idx, batch_idx + 1))
                    self.memory_pool.allocate(label_tensor_id, labels.element_size() * labels.nelement(), (batch_idx, batch_idx + 1))
                except MemoryError as err:
                    # Release whatever was booked before the failure (free()
                    # ignores unknown ids), then skip the batch visibly.
                    self.memory_pool.free(input_tensor_id)
                    self.memory_pool.free(label_tensor_id)
                    tqdm.write(f"[WARN] Skipping batch {batch_idx}: {err}")
                    continue

                try:
                    if self.has_bn:
                        loss, acc = self._train_step_with_recomputation(inputs, labels)
                    else:
                        loss, acc = self._train_step_with_microbatch(inputs, labels)
                finally:
                    # Always hand the bookkeeping entries back, even if the
                    # step raised.
                    self.memory_pool.free(input_tensor_id)
                    self.memory_pool.free(label_tensor_id)

                running_loss += loss
                correct += acc[0]
                total += acc[1]
                processed_batches += 1

        self.last_profile = prof

        end_time = time.monotonic()
        # Average over the batches that actually ran; guards the empty case.
        epoch_loss = running_loss / processed_batches if processed_batches > 0 else 0.0
        accuracy = 100 * correct / total if total > 0 else 0

        return epoch_loss, accuracy, start_time, end_time

    def _forward_logits(self, inputs):
        """Run the model and return its first output when it returns a tuple."""
        outputs = self.model(inputs)
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        return outputs

    def _train_step_with_recomputation(self, inputs, labels):
        """Run one full-batch optimisation step.

        NOTE: no activation recomputation takes place here. An earlier
        revision ran an additional no-grad forward pass through hard-coded
        ResNet attributes before the real one; its outputs were never used,
        it advanced the BatchNorm running statistics a second time per batch,
        and it failed on models without those attributes, so it was removed.

        Returns:
            ``(loss_value, (n_correct, n_samples))``.
        """
        self.optimizer.zero_grad()

        with record_function("forward_pass"):
            outputs = self._forward_logits(inputs)

        with record_function("loss_computation"):
            loss = self.criterion(outputs, labels)

        with record_function("backward_pass"):
            loss.backward()

        with record_function("optimizer_step"):
            self.optimizer.step()

        predicted = outputs.detach().argmax(dim=1)
        correct = (predicted == labels).sum().item()
        total = labels.size(0)

        return loss.item(), (correct, total)

    def _train_step_with_microbatch(self, inputs, labels):
        """Run one optimisation step with gradient accumulation.

        The batch is cut into micro-batches sized by
        ``_calculate_micro_batch_size``. Each micro-batch loss is weighted by
        its share of the batch before ``backward`` so the accumulated gradient
        matches a full-batch step.
        NOTE(review): the weighting assumes the criterion averages over the
        batch (``reduction='mean'``) — confirm for custom criteria.

        Returns:
            ``(loss_value, (n_correct, n_samples))``.
        """
        self.optimizer.zero_grad()

        batch_size = labels.size(0)
        micro_batch_size = self._calculate_micro_batch_size(inputs.shape, inputs.element_size())

        accumulated_loss = 0.0
        correct = 0
        for start in range(0, batch_size, micro_batch_size):
            micro_inputs = inputs[start:start + micro_batch_size]
            micro_labels = labels[start:start + micro_batch_size]

            with record_function("forward_pass"):
                outputs = self._forward_logits(micro_inputs)

            with record_function("loss_computation"):
                share = micro_labels.size(0) / batch_size
                loss = self.criterion(outputs, micro_labels) * share

            with record_function("backward_pass"):
                loss.backward()

            accumulated_loss += loss.item()
            correct += (outputs.detach().argmax(dim=1) == micro_labels).sum().item()

        with record_function("optimizer_step"):
            self.optimizer.step()

        return accumulated_loss, (correct, batch_size)

    def _calculate_micro_batch_size(self, input_size, element_size=4):
        """Return how many samples of shape ``input_size[1:]`` fit in the budget.

        Args:
            input_size: Batch shape with the batch dimension first.
            element_size: Bytes per element (4 matches float32).

        Returns:
            A value between 1 and ``input_size[0]`` (at least 1).
        """
        bytes_per_sample = element_size
        for dim in input_size[1:]:
            bytes_per_sample *= dim

        if bytes_per_sample <= 0:
            # Degenerate sample shape: nothing to budget for.
            return max(1, input_size[0])
        return max(1, min(input_size[0], self.memory_budget // bytes_per_sample))

In [None]:
def evaluate(model, data_loader, criterion, device, phase="Validation"):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    The model is switched to eval mode and left there. Models may return
    either the logits directly or a tuple whose first element is the logits
    (the same convention the trainer uses).

    Args:
        model: Network to evaluate.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        phase: Label shown on the progress bar.

    Returns:
        ``(epoch_loss, accuracy)``: mean per-batch loss and accuracy in
        percent; ``(0.0, 0.0)`` when the loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc=f"{phase}"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Indexing a plain tensor with [0] would select the first sample,
            # so only unpack when the model really returned a tuple.
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            loss = criterion(logits, labels)

            running_loss += loss.item()
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            n_batches += 1

    if n_batches == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / n_batches
    accuracy = 100 * correct / total if total > 0 else 0.0

    return epoch_loss, accuracy

In [None]:
def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Args:
        start_time: Start timestamp in seconds.
        end_time: End timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as ints, both truncated towards zero.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [None]:
torch.cuda.empty_cache()

In [None]:
print_gpu_utilization()

In [None]:
free_memory, total_memory = torch.cuda.mem_get_info()
print(f"Free memory: {free_memory / 1024**2:.2f} MB")
print(f"Total memory: {total_memory / 1024**2:.2f} MB")

In [None]:
# Trainer configured with a 4 GiB bookkeeping budget (4096 * 1024**2 bytes).
trainer = MelonTrainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    memory_budget=4096 * 1024**2
)

EPOCHS = 10
best_valid_loss = float('inf')
total_time = 0
for epoch in range(EPOCHS):
    # trainer.train() measures its own start/end with time.monotonic(); those
    # readings are used directly so the reported time covers training only
    # (no separate timers that were immediately overwritten, and no memory
    # reporting inside the measured span).
    train_loss, train_acc, start_time, end_time = trainer.train(train_iterator)

    print_gpu_utilization()
    # mem_get_info() needs a CUDA device; skip it on the CPU fallback.
    if torch.cuda.is_available():
        free_memory, total_memory = torch.cuda.mem_get_info()
        print(f"Free memory: {free_memory / 1024**2:.2f} MB")
        print(f"Total memory: {total_memory / 1024**2:.2f} MB")

    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)
    # Keep track of the best validation loss seen so far.
    best_valid_loss = min(best_valid_loss, valid_loss)

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    total_time += end_time - start_time

    print(f'Epoch: {epoch+1:02} | Epoch Train Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc:.2f}%')

print("Train finished")
print_gpu_utilization()

In [None]:
print_gpu_utilization()

In [None]:
free_memory, total_memory = torch.cuda.mem_get_info()
print(f"Free memory: {free_memory / 1024**2:.2f} MB")
print(f"Total memory: {total_memory / 1024**2:.2f} MB")

In [None]:
print("ResNet18 with Melon")
print(f'Total Training Time: {int(total_time/60)}m {int(total_time%60)}s')

In [None]:
from torch import profiler

# Profile a single inference pass on a random batch shaped like the training
# batches. The batch is created on `device` (instead of an unconditional
# .cuda()) so the cell also runs on the CPU fallback.
dummy_input = torch.randn(batch_size, 3, size, size, device=device)

# Only request CUDA activity when a CUDA device exists.
profiled_activities = [profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    profiled_activities.append(profiler.ProfilerActivity.CUDA)

# Inference mode: BatchNorm must use its running statistics and must not be
# updated by the random dummy batch.
model.eval()

# Profiling inference
with profiler.profile(
    activities=profiled_activities,
    on_trace_ready=profiler.tensorboard_trace_handler("./logs"),  # Optional logging
    record_shapes=True,
    with_stack=True
) as prof:
    with torch.no_grad():
        model(dummy_input)


# Print results
print(prof.key_averages().table(sort_by="cuda_time_total" if torch.cuda.is_available() else "cpu_time_total", row_limit=50))