# SANet Training

Training the SANet model on Kaggle with a T4 GPU


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import os
from PIL import Image, ImageFile
import random
import glob
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as F
import pickle
import shutil

# --- Data and pretrained-weight locations (Kaggle input mounts) ---
CONTENT_FOLDER = "/kaggle/input/coco-2017-dataset/coco2017"
STYLE_FOLDER = "/kaggle/input/wikisample/wikiart_sampled"
VGG_PATH = "/kaggle/input/vggandatautils/vgg_normalised.pth"

# --- Optimisation schedule ---
BATCH_SIZE = 8
LR = 1e-4              # base learning rate
LR_DECAY = 5e-5        # inverse-time decay: lr = LR / (1 + LR_DECAY * iteration)
MAX_ITER = 160_000     # total training iterations
SAVE_INTERVAL = 5_000  # iterations between validation + checkpoint
LOG_INTERVAL = 1_000   # iterations between training-loss log lines
PATIENCE = 999_999     # validation rounds without improvement before stopping; this value effectively disables early stopping


## Model Architecture


In [None]:
def calc_mean_std(feat, eps=1e-5):
    """Return per-sample, per-channel mean and standard deviation of ``feat``.

    Args:
        feat: 4-D tensor laid out as (N, C, H, W).
        eps: added to the (unbiased) variance before the square root, so the
            std stays strictly positive for constant channels.

    Returns:
        ``(mean, std)``, each shaped (N, C, 1, 1) so they broadcast against
        ``feat``.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
    # result of a transpose/permute; flatten the spatial dims only once.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def mean_variance_norm(feat):
    """Standardise ``feat`` to zero mean / unit std per sample and channel.

    The (N, C, 1, 1) statistics from ``calc_mean_std`` are expanded over the
    spatial dimensions before subtracting and dividing.
    """
    mean, std = calc_mean_std(feat)
    full_shape = feat.size()
    centred = feat - mean.expand(full_shape)
    return centred / std.expand(full_shape)

# Decoder: turns the 512-channel output of ``Transform`` back into an RGB
# image. Its three x2 nearest-neighbour upsamples undo the three max-pools that
# precede relu4_1 in ``vgg``; reflection padding keeps every 3x3 conv
# size-preserving. The last conv has no activation (the inference helpers clamp
# the output to [0, 1]).
# Layer order matters: nn.Sequential names parameters by position
# ("1.weight", ...), so inserting/removing layers breaks saved decoder weights.
decoder = nn.Sequential(
    # 512 -> 256 channels, then upsample x2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    # 256 -> 256 (three convs), 256 -> 128, then upsample x2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    # 128 -> 128, 128 -> 64, then upsample x2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    # 64 -> 64, then 64 -> 3 (RGB output, no activation)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

# VGG-19-style encoder (layout only; weights are loaded later from VGG_PATH).
# ``Net`` slices its children by index into five frozen stages ending at
# relu1_1 (index 3), relu2_1 (10), relu3_1 (17), relu4_1 (30) and relu5_1 (43);
# after loading, the training setup keeps only children [:44].
# Do not insert or remove layers: those slice indices and the positional keys
# of the pretrained state dict both depend on the exact order.
# Pools use ceil_mode=True, so odd spatial sizes are rounded up.
vgg = nn.Sequential(
    # 0: 1x1 conv, 3 -> 3 channels, applied to the raw input. Its weights come
    # from VGG_PATH (presumably a baked-in input normalisation, as the name
    # vgg_normalised.pth suggests -- TODO confirm).
    nn.Conv2d(3, 3, (1, 1)),
    # 1-3: conv1_1, relu1_1 (index 3 = end of enc_1)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),
    # 4-7: conv1_2, relu1_2, pool1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    # 8-10: conv2_1, relu2_1 (index 10 = end of enc_2)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),
    # 11-14: conv2_2, relu2_2, pool2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    # 15-17: conv3_1, relu3_1 (index 17 = end of enc_3)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),
    # 18-27: conv3_2, conv3_3, conv3_4 (each + ReLU), pool3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    # 28-30: conv4_1, relu4_1 (index 30 = end of enc_4)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),
    # 31-40: conv4_2, conv4_3, conv4_4 (each + ReLU), pool4
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    # 41-43: conv5_1, relu5_1 (index 43 = end of enc_5)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    # 44-52: conv5_2, conv5_3, conv5_4 (each + ReLU); needed to load the
    # pretrained state dict, then dropped by the [:44] slice.
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()
)

class SANet(nn.Module):
    """Style-attentional block: redistributes style features onto content.

    Every content position attends over all style positions (softmax over
    the style axis). Queries come from the mean-variance-normalised content,
    keys from the normalised style, values from the raw style features. The
    attended result goes through a 1x1 conv and is added back to the content.

    Args:
        in_planes: channel count shared by the content and style features.
    """

    def __init__(self, in_planes):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))  # query projection
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))  # key projection
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))  # value projection
        self.sm = nn.Softmax(dim = -1)
        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))

    def forward(self, content, style):
        """Restyle ``content`` (N, C, Hc, Wc) with ``style`` (N, C, Hs, Ws).

        Returns a tensor shaped like ``content``.
        """
        query = self.f(mean_variance_norm(content))
        key = self.g(mean_variance_norm(style))
        value = self.h(style)

        # (N, Hc*Wc, C) @ (N, C, Hs*Ws) -> (N, Hc*Wc, Hs*Ws) attention map
        q_rows = query.flatten(2).permute(0, 2, 1)
        k_cols = key.flatten(2)
        attention = self.sm(torch.bmm(q_rows, k_cols))

        # (N, C, Hs*Ws) @ (N, Hs*Ws, Hc*Wc) -> (N, C, Hc*Wc)
        mixed = torch.bmm(value.flatten(2), attention.permute(0, 2, 1))
        mixed = mixed.view(content.size())

        result = self.out_conv(mixed)
        result += content  # residual connection
        return result

class Transform(nn.Module):
    """Fuse SANet outputs computed at the relu4_1 and relu5_1 levels.

    The relu5_1-level result is brought to the relu4_1 spatial size, added to
    the relu4_1-level result, and merged by a reflection-padded 3x3 conv.

    Args:
        in_planes: channel count of both feature levels.
    """

    def __init__(self, in_planes):
        super(Transform, self).__init__()
        # Attribute names are the state_dict key prefixes; keep them stable.
        self.sanet4_1 = SANet(in_planes = in_planes)
        self.sanet5_1 = SANet(in_planes = in_planes)
        self.upsample5_1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.merge_conv_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.merge_conv = nn.Conv2d(in_planes, in_planes, (3, 3))

    def forward(self, content4_1, style4_1, content5_1, style5_1):
        """Return fused stylised features with the spatial size of ``content4_1``."""
        fused4 = self.sanet4_1(content4_1, style4_1)
        fused5 = self.sanet5_1(content5_1, style5_1)
        target_hw = tuple(fused4.shape[-2:])
        if tuple(2 * s for s in fused5.shape[-2:]) == target_hw:
            # Common case (even relu4_1 size): exact x2, same as before.
            fused5 = self.upsample5_1(fused5)
        else:
            # The encoder pools with ceil_mode=True, so an odd relu4_1 size H
            # gives ceil(H/2) at relu5_1 and a fixed x2 upsample would yield
            # H+1, making the addition below fail. Resize to the exact size.
            fused5 = nn.functional.interpolate(fused5, size=target_hw, mode='nearest')
        return self.merge_conv(self.merge_conv_pad(fused4 + fused5))

class Net(nn.Module):
    """Training wrapper: frozen VGG encoder + SANet ``Transform`` + decoder.

    ``forward`` returns the four unweighted SANet losses
    ``(loss_c, loss_s, l_identity1, l_identity2)``; the caller weights them.

    Args:
        encoder: ``nn.Sequential`` whose children are sliced by index into
            the five encoder stages ``enc_1`` .. ``enc_5``.
        decoder: module mapping 512-channel features back to an image.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        layers = list(encoder.children())
        # Each stage ends at a relu{k}_1 activation of the module-level vgg
        # (layer indices 3, 10, 17, 30 and 43).
        self.enc_1 = nn.Sequential(*layers[:4])
        self.enc_2 = nn.Sequential(*layers[4:11])
        self.enc_3 = nn.Sequential(*layers[11:18])
        self.enc_4 = nn.Sequential(*layers[18:31])
        self.enc_5 = nn.Sequential(*layers[31:44])
        self.transform = Transform(in_planes = 512)
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()
        # The encoder is a fixed feature extractor: freeze all its weights.
        for stage in (self.enc_1, self.enc_2, self.enc_3, self.enc_4, self.enc_5):
            for weight in stage.parameters():
                weight.requires_grad = False

    def encode_with_intermediate(self, input):
        """Run ``input`` through the five encoder stages; return every stage's output."""
        outputs = []
        current = input
        for stage in (self.enc_1, self.enc_2, self.enc_3, self.enc_4, self.enc_5):
            current = stage(current)
            outputs.append(current)
        return outputs

    def calc_content_loss(self, input, target, norm = False):
        """MSE between ``input`` and ``target``; with ``norm``, both are
        mean-variance normalised per channel first."""
        if norm != False:  # noqa: E712 -- any value other than False normalises
            input, target = mean_variance_norm(input), mean_variance_norm(target)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """Sum of MSEs between the per-channel means and stds of ``input`` and ``target``."""
        in_mean, in_std = calc_mean_std(input)
        tg_mean, tg_std = calc_mean_std(target)
        mean_term = self.mse_loss(in_mean, tg_mean)
        std_term = self.mse_loss(in_std, tg_std)
        return mean_term + std_term

    def forward(self, content, style):
        style_feats = self.encode_with_intermediate(style)
        content_feats = self.encode_with_intermediate(content)
        c4, c5 = content_feats[3], content_feats[4]
        s4, s5 = style_feats[3], style_feats[4]

        # Stylised image and its encoder features.
        g_t = self.decoder(self.transform(c4, s4, c5, s5))
        g_t_feats = self.encode_with_intermediate(g_t)

        # Content loss: normalised relu4_1 and relu5_1 features.
        loss_c = (self.calc_content_loss(g_t_feats[3], c4, norm = True)
                  + self.calc_content_loss(g_t_feats[4], c5, norm = True))

        # Style loss: mean/std matching at all five stages.
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for gen_feat, sty_feat in zip(g_t_feats[1:], style_feats[1:]):
            loss_s += self.calc_style_loss(gen_feat, sty_feat)

        # Identity losses: stylising an image with itself should reproduce it.
        # NOTE: the decoder output is 8x the relu4_1 size, so the pixel-level
        # term assumes input H and W are multiples of 8 (true for 256 crops).
        Icc = self.decoder(self.transform(c4, c4, c5, c5))
        Iss = self.decoder(self.transform(s4, s4, s5, s5))
        l_identity1 = self.calc_content_loss(Icc, content) + self.calc_content_loss(Iss, style)

        Fcc = self.encode_with_intermediate(Icc)
        Fss = self.encode_with_intermediate(Iss)
        l_identity2 = self.calc_content_loss(Fcc[0], content_feats[0]) + self.calc_content_loss(Fss[0], style_feats[0])
        for fcc, fc, fss, fs in zip(Fcc[1:], content_feats[1:], Fss[1:], style_feats[1:]):
            l_identity2 += self.calc_content_loss(fcc, fc) + self.calc_content_loss(fss, fs)

        return loss_c, loss_s, l_identity1, l_identity2


## Data Utils


In [None]:
class TransformImageNet:
    """Preprocess a PIL image into a tensor for SANet training.

    Pipeline: optional grayscale (kept as 3-channel RGB) -> resize so the
    longer side equals ``target_long`` (aspect ratio kept) -> zero-pad the
    right/bottom edges up to ``min_short`` -> random horizontal flip ->
    optional random crop -> ``ToTensor`` -> optional ImageNet normalisation.

    Args:
        target_long: length of the longer side after resizing.
        min_short: minimum length of each side after padding.
        crop_size: size passed to ``T.RandomCrop``; falsy disables cropping.
        gray_ratio: probability of the grayscale conversion.
        use_normalize: apply ImageNet mean/std ``T.Normalize``.
    """

    def __init__(self, target_long=512, min_short=256, crop_size=None, 
                 gray_ratio=0.0, use_normalize=True):
        self.target_long = target_long
        self.min_short = min_short
        self.crop_size = crop_size
        self.gray_ratio = gray_ratio
        self.use_normalize = use_normalize
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])

    def resize_and_pad(self, img):
        """Resize the longer side to ``target_long``, then pad right/bottom
        with zeros so both sides are at least ``min_short``."""
        w, h = img.size
        if w > h:
            new_w = self.target_long
            # max(1, ...): a very elongated image would otherwise get a
            # 0-pixel side, which PIL's resize rejects.
            new_h = max(1, int(h * self.target_long / w))
        else:
            new_h = self.target_long
            new_w = max(1, int(w * self.target_long / h))
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
        pad_w = max(0, self.min_short - new_w)
        pad_h = max(0, self.min_short - new_h)
        if pad_w > 0 or pad_h > 0:
            # 4-tuple padding is (left, top, right, bottom).
            img = F.pad(img, (0,0,pad_w,pad_h), fill=0)
        return img

    def __call__(self, img):
        if random.random() < self.gray_ratio:
            img = img.convert("L").convert("RGB")
        img = self.resize_and_pad(img)
        img = T.RandomHorizontalFlip(p=0.5)(img)
        if self.crop_size:
            img = T.RandomCrop(self.crop_size)(img)
        img = self.to_tensor(img)
        if self.use_normalize:
            img = self.normalize(img)
        return img

def InfiniteSampler(n):
    """Yield indices from ``range(n)`` forever, reshuffled on every pass.

    The very first index is the last entry of an initial permutation drawn
    from numpy's current global RNG state. After that, whole fresh
    permutations follow; before each one, numpy's global RNG is reseeded from
    OS entropy (``np.random.seed()`` with no argument).
    """
    initial = np.random.permutation(n)
    yield initial[n - 1]
    while True:
        np.random.seed()
        yield from np.random.permutation(n)

class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    """``Sampler`` adapter that feeds ``InfiniteSampler`` indices to a DataLoader.

    Iteration never terminates; ``__len__`` reports the constant 2**31
    because the real length is unbounded.
    """

    def __init__(self, data_source):
        # Only the dataset size is needed; the dataset object is not kept.
        self.num_samples = len(data_source)

    def __len__(self):
        return 1 << 31

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

class CustomImageDataset(Dataset):
    """Pairs every content image with a randomly drawn style image.

    Files are collected (non-recursively) from
    ``<content_folder>/<content_subset>`` and ``<style_folder>/<style_subset>``.
    There is one item per content image; the style image is re-drawn with
    ``random.choice`` on every access.

    Args:
        content_folder, style_folder: dataset roots.
        content_subset, style_subset: sub-directory names under each root.
        transform: callable applied to both RGB PIL images, if given.
        gray_ratio: stored on the instance; not used by this class itself.
        valid_ext: accepted file extensions, matched case-insensitively.

    Raises:
        RuntimeError: if either folder yields no matching files.
    """

    def __init__(self, content_folder, style_folder, 
                 content_subset, style_subset,
                 transform=None, gray_ratio=0.0,
                 valid_ext=('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
        self.content_folder = os.path.join(content_folder, content_subset)
        self.style_folder = os.path.join(style_folder, style_subset)

        self.content_files = self._list_images(self.content_folder, valid_ext)
        self.style_files = self._list_images(self.style_folder, valid_ext)

        if len(self.content_files) == 0:
            raise RuntimeError(f"No content images found in {self.content_folder}")
        if len(self.style_files) == 0:
            raise RuntimeError(f"No style images found in {self.style_folder}")

        self.transform = transform
        self.gray_ratio = gray_ratio

    @staticmethod
    def _list_images(folder, valid_ext):
        """Return sorted paths of regular files in ``folder`` with an accepted extension.

        Unlike a per-extension glob on a case-sensitive filesystem, ``a.JPG``
        matches ``.jpg``. Hidden files are skipped and a missing folder yields
        an empty list, mirroring glob's behaviour.
        """
        if not os.path.isdir(folder):
            return []
        exts = tuple(ext.lower() for ext in valid_ext)
        with os.scandir(folder) as entries:
            return sorted(
                entry.path
                for entry in entries
                if not entry.name.startswith('.')
                and entry.name.lower().endswith(exts)
                and entry.is_file()
            )

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB image and close the file handle afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.content_files)

    def __getitem__(self, idx):
        content_img = self._load_rgb(self.content_files[idx])
        style_img = self._load_rgb(random.choice(self.style_files))

        if self.transform:
            content_img = self.transform(content_img)
            style_img = self.transform(style_img)

        return content_img, style_img


## Extract Checkpoint

In [None]:
checkpoint_input = "/kaggle/input/vggandresult"

# Copy a previous run's outputs from the read-only input mount into the
# writable working directory so the training cell can resume from them.
if os.path.exists("/kaggle/working/latest_checkpoint.pth"):
    # Re-running this cell mid-session must not overwrite newer in-session
    # checkpoints with the older ones from the input dataset.
    print("Checkpoint already present in /kaggle/working; not overwriting it.")
elif os.path.exists(f"{checkpoint_input}/latest_checkpoint.pth"):
    print("Copying checkpoint files to /kaggle/working/...")

    copied = []
    # Only latest_checkpoint.pth is known to exist; the "best" weights are optional.
    for file_name in ("latest_checkpoint.pth", "decoder_best.pth", "transformer_best.pth"):
        if os.path.exists(f"{checkpoint_input}/{file_name}"):
            shutil.copy(f"{checkpoint_input}/{file_name}", "/kaggle/working/")
            copied.append(file_name)

    for dir_name in ("checkpoints", "samples"):
        src_dir = f"{checkpoint_input}/{dir_name}"
        dst_dir = f"/kaggle/working/{dir_name}"
        # copytree raises FileExistsError if the destination already exists.
        if os.path.exists(src_dir) and not os.path.exists(dst_dir):
            shutil.copytree(src_dir, dst_dir)
            copied.append(f"{dir_name}/ folder")

    print("Checkpoint extracted! Training will resume automatically.")
    print("Files copied:")
    for item in copied:
        print(f"  - {item}")
else:
    print("No checkpoint found. Starting from scratch.")

## Dataset and DataLoader Setup


In [None]:
# Global runtime switches.
torch.backends.cudnn.benchmark = True    # fixed 256x256 crops -> let cuDNN autotune kernels
Image.MAX_IMAGE_PIXELS = None            # disable PIL's decompression-bomb size limit
ImageFile.LOAD_TRUNCATED_IMAGES = True   # tolerate truncated image files instead of raising

# Train and validation use identical preprocessing. NOTE: this includes the
# random flip/crop (and CustomImageDataset draws the style image at random),
# so validation batches differ from one evaluation to the next.
_transform_kwargs = dict(
    target_long=512,
    min_short=256,
    crop_size=256,
    gray_ratio=0.0,
    use_normalize=False,
)
train_transform = TransformImageNet(**_transform_kwargs)
val_transform = TransformImageNet(**_transform_kwargs)

train_dataset = CustomImageDataset(
    CONTENT_FOLDER, STYLE_FOLDER,
    content_subset="train2017", style_subset="train",
    transform=train_transform, gray_ratio=0.0,
)
val_dataset = CustomImageDataset(
    CONTENT_FOLDER, STYLE_FOLDER,
    content_subset="val2017", style_subset="valid",
    transform=val_transform, gray_ratio=0.0,
)

_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
# Training draws batches forever through the infinite sampler (next(train_iter)).
train_iter = iter(DataLoader(
    train_dataset,
    sampler=InfiniteSamplerWrapper(train_dataset),
    **_loader_kwargs,
))
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Train dataset: {len(train_dataset)} content images, {len(train_dataset.style_files)} style images")
print(f"Validation dataset: {len(val_dataset)} content images, {len(val_dataset.style_files)} style images")


## Training Setup


In [None]:
device = torch.device("cuda")

# Load the pretrained VGG weights, then keep layers 0..43 (through relu5_1),
# which is everything Net slices into its encoder stages.
vgg.load_state_dict(torch.load(VGG_PATH))
vgg = vgg[:44]

network = Net(vgg, decoder).to(device)
network.train()

# Only the decoder and the SANet transform are optimised; Net freezes the encoder.
trainable_groups = [
    {'params': network.decoder.parameters()},
    {'params': network.transform.parameters()},
]
optimizer = torch.optim.Adam(trainable_groups, lr=LR)

# NOTE(review): the scaler only has an effect if the training loop routes the
# loss/step through it (and runs the forward under autocast).
scaler = GradScaler()

print("Model initialized and moved to GPU")
print("Mixed Precision Training enabled")


## Inference & Testing Functions


In [None]:
def style_transfer_sanet(network, content_path, style_path, device, output_size=None):
    """Stylise one content image with one style image using a trained ``Net``.

    Args:
        network: ``Net`` whose encoder, transform and decoder are used; it is
            assumed to already live on ``device``.
        content_path, style_path: image files (converted to RGB).
        device: device the input tensors are moved to.
        output_size: if truthy, passed to ``T.Resize`` for both images before
            stylisation (an int resizes the shorter side).

    Returns:
        A (1, 3, H, W) tensor clamped to [0, 1]. The network's previous
        train/eval mode is restored before returning.
    """
    if output_size:
        transform = T.Compose([
            T.Resize(output_size),
            T.ToTensor()
        ])
    else:
        transform = T.ToTensor()

    # Context managers close the underlying files once the pixels are read.
    with Image.open(content_path) as img:
        content = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with Image.open(style_path) as img:
        style = transform(img.convert("RGB")).unsqueeze(0).to(device)

    was_training = network.training
    network.eval()
    try:
        with torch.no_grad():
            style_feats = network.encode_with_intermediate(style)
            content_feats = network.encode_with_intermediate(content)
            stylized = network.transform(
                content_feats[3], style_feats[3],
                content_feats[4], style_feats[4]
            )
            output = network.decoder(stylized)
    finally:
        # Leave the caller's train/eval mode as it was.
        network.train(was_training)

    return output.clamp(0, 1)

def test_sanet(network, content_dir, style_dir, device, checkpoint_path, save_dir="/kaggle/working/test_results", num_samples=10):
    """Load a training checkpoint and stylise a few content images.

    For each of the first ``num_samples`` ``*.jpg`` files in ``content_dir``
    (sorted), a random ``*.jpg`` from ``style_dir`` is chosen, and the
    resized content, style and stylised output are written to
    ``<save_dir>/pair_<k>/``.

    ``checkpoint_path`` must hold a dict with ``'decoder'`` and
    ``'transform'`` state dicts, as written by ``save_checkpoint``. If the
    checkpoint is missing or either image folder has no ``*.jpg`` files, a
    message is printed and the function returns early.
    """
    import torchvision.utils as vutils

    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        print("Available checkpoints:")
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if os.path.exists(checkpoint_dir):
            for f in os.listdir(checkpoint_dir):
                if f.endswith('.pth'):
                    print(f"  - {os.path.join(checkpoint_dir, f)}")
        return

    content_paths = sorted(glob.glob(os.path.join(content_dir, "*.jpg")))[:num_samples]
    style_paths = sorted(glob.glob(os.path.join(style_dir, "*.jpg")))
    # random.choice below would raise an opaque IndexError on an empty list.
    if not content_paths:
        print(f"Error: No .jpg content images found in {content_dir}")
        return
    if not style_paths:
        print(f"Error: No .jpg style images found in {style_dir}")
        return

    os.makedirs(save_dir, exist_ok=True)

    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    network.decoder.load_state_dict(checkpoint['decoder'])
    network.transform.load_state_dict(checkpoint['transform'])
    network.to(device)
    network.eval()

    print(f"Testing {len(content_paths)} content images...")

    # Same preprocessing as style_transfer_sanet(output_size=512); used to
    # save the inputs next to the output.
    preview = T.Compose([T.Resize(512), T.ToTensor()])

    with torch.no_grad():
        for idx, content_path in enumerate(content_paths, start=1):
            style_path = random.choice(style_paths)

            pair_dir = os.path.join(save_dir, f"pair_{idx}")
            os.makedirs(pair_dir, exist_ok=True)

            output = style_transfer_sanet(network, content_path, style_path, device, output_size=512)

            with Image.open(content_path) as img:
                content_tensor = preview(img.convert("RGB")).unsqueeze(0)
            with Image.open(style_path) as img:
                style_tensor = preview(img.convert("RGB")).unsqueeze(0)

            vutils.save_image(content_tensor, os.path.join(pair_dir, "content.jpg"))
            vutils.save_image(style_tensor, os.path.join(pair_dir, "style.jpg"))
            vutils.save_image(output, os.path.join(pair_dir, "output.jpg"))

    print(f"Testing completed! Results saved in {save_dir}")

print("Inference functions ready")


## Training Loop with Early Stopping


In [None]:
def adjust_learning_rate(optimizer, iteration):
    """Set every param group's lr to ``LR / (1 + LR_DECAY * iteration)``.

    Returns the learning rate that was applied.
    """
    new_lr = LR / (1.0 + LR_DECAY * iteration)
    for group in optimizer.param_groups:
        group.update(lr=new_lr)
    return new_lr

def save_checkpoint(iteration, val_loss, best_loss, patience_counter, is_best=False):
    """Persist training state of the global ``network``, ``optimizer`` and ``scaler``.

    Writes ``/kaggle/working/checkpoints/checkpoint_iter_<iteration>.pth`` and
    refreshes ``/kaggle/working/latest_checkpoint.pth``. When ``is_best`` is
    set, also writes the bare decoder/transform state dicts to
    ``decoder_best.pth`` / ``transformer_best.pth``.

    Every file is written under a temporary name and then renamed with
    ``os.replace``, so an interruption mid-save leaves the previous file
    intact instead of a truncated checkpoint.
    """
    os.makedirs("/kaggle/working/checkpoints", exist_ok=True)

    checkpoint = {
        'iteration': iteration,
        'decoder': network.decoder.state_dict(),
        'transform': network.transform.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scaler': scaler.state_dict(),
        'val_loss': val_loss,
        'best_loss': best_loss,
        'patience_counter': patience_counter
    }

    def _atomic_save(obj, path):
        # Same directory as the target, so os.replace is an atomic rename.
        tmp_path = path + ".tmp"
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    _atomic_save(checkpoint, f"/kaggle/working/checkpoints/checkpoint_iter_{iteration}.pth")
    _atomic_save(checkpoint, "/kaggle/working/latest_checkpoint.pth")

    if is_best:
        _atomic_save(network.decoder.state_dict(), "/kaggle/working/decoder_best.pth")
        _atomic_save(network.transform.state_dict(), "/kaggle/working/transformer_best.pth")

def load_checkpoint(checkpoint_path):
    """Restore training state from ``checkpoint_path`` when that file exists.

    Loads decoder, transform and optimizer state, plus the GradScaler state
    when the checkpoint contains one, into the global objects.

    Returns:
        ``(start_iter, best_loss, patience_counter)``; ``(0, inf, 0)`` when
        there is nothing to resume from.
    """
    if not os.path.exists(checkpoint_path):
        return 0, float('inf'), 0

    print(f"Loading checkpoint from {checkpoint_path}")
    state = torch.load(checkpoint_path, map_location=device)
    network.decoder.load_state_dict(state['decoder'])
    network.transform.load_state_dict(state['transform'])
    optimizer.load_state_dict(state['optimizer'])
    if 'scaler' in state:
        scaler.load_state_dict(state['scaler'])

    start_iter = state['iteration']
    best_loss = state['best_loss']
    patience_counter = state['patience_counter']
    print(f"Resume from iteration {start_iter} | Best loss: {best_loss:.4f} | Patience: {patience_counter}/{PATIENCE}")
    return start_iter, best_loss, patience_counter

def validate(val_loader, network, device, max_batches=50):
    """Return the mean weighted SANet loss over at most ``max_batches`` batches.

    Uses the training weighting ``1*content + 3*style + 50*identity1 +
    identity2``. Runs under ``torch.no_grad`` with the network in eval mode
    and always switches it back to train mode afterwards, even if the
    forward pass raises.

    Args:
        val_loader: iterable of ``(content_images, style_images)`` batches.
        network: module returning ``(loss_c, loss_s, l_identity1, l_identity2)``.
        device: device the batches are moved to.
        max_batches: stop after this many batches (at least one is processed).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    network.eval()
    total_loss = 0.0
    num_batches = 0

    try:
        with torch.no_grad():
            for content_images, style_images in val_loader:
                content_images = content_images.to(device)
                style_images = style_images.to(device)

                loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)
                loss_c = 1.0 * loss_c
                loss_s = 3.0 * loss_s
                loss = loss_c + loss_s + 50 * l_identity1 + l_identity2

                total_loss += loss.item()
                num_batches += 1

                if num_batches >= max_batches:
                    break
    finally:
        network.train()

    if num_batches == 0:
        raise ValueError("validation loader yielded no batches")
    return total_loss / num_batches

def generate_sample(network, val_loader, device, iteration):
    """Save one (content, style, stylised output) triple for visual inspection.

    Uses the first pair of the first ``val_loader`` batch and writes
    ``output_<iteration>.jpg``, ``content_<iteration>.jpg`` and
    ``style_<iteration>.jpg`` into /kaggle/working/samples. The network runs
    in eval mode and is switched back to train mode at the end.
    """
    network.eval()
    os.makedirs("/kaggle/working/samples", exist_ok=True)

    with torch.no_grad():
        batch_content, batch_style = next(iter(val_loader))
        content = batch_content[:1].to(device)
        style = batch_style[:1].to(device)

        s_feats = network.encode_with_intermediate(style)
        c_feats = network.encode_with_intermediate(content)
        fused = network.transform(c_feats[3], s_feats[3], c_feats[4], s_feats[4])
        stylised = network.decoder(fused).clamp(0, 1)

        import torchvision.utils as vutils
        outputs = (
            (stylised, f"/kaggle/working/samples/output_{iteration}.jpg"),
            (content, f"/kaggle/working/samples/content_{iteration}.jpg"),
            (style, f"/kaggle/working/samples/style_{iteration}.jpg"),
        )
        for image, path in outputs:
            vutils.save_image(image, path)

    network.train()

start_iter, best_loss, patience_counter = load_checkpoint("/kaggle/working/latest_checkpoint.pth")

history = {
    'iterations': [],
    'train_loss': [],
    'train_loss_c': [],
    'train_loss_s': [],
    'train_loss_id1': [],
    'train_loss_id2': [],
    'val_loss': [],
    'lr': []
}


def _dump_history(path='/kaggle/working/training_history.pkl'):
    """Pickle the global ``history`` dict to ``path``.

    Called at every checkpoint as well as at the end, so the recorded curves
    survive a session that stops before MAX_ITER. NOTE: after a resume,
    ``history`` starts empty, so only the current session's iterations are
    recorded.
    """
    with open(path, 'wb') as fh:
        pickle.dump(history, fh)


print("Starting training...")
print(f"Max iterations: {MAX_ITER}")
print(f"Save interval: {SAVE_INTERVAL}")
print(f"Log interval: {LOG_INTERVAL}")
print(f"Patience: {PATIENCE}")
print("-" * 50)

# Mixed precision: run the forward pass under autocast and route backward/step
# through the GradScaler from the setup cell (its state is already saved and
# restored with checkpoints). A scaler built with enabled=False turns both off.
# NOTE(review): attention and feature statistics now partly run in fp16 --
# watch the logged losses for inf/NaN after enabling this.
use_amp = scaler.is_enabled()

for i in range(start_iter, MAX_ITER):
    current_lr = adjust_learning_rate(optimizer, i)

    content_images, style_images = next(train_iter)
    content_images = content_images.to(device)
    style_images = style_images.to(device)

    with torch.autocast(device_type=device.type, enabled=use_amp):
        loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)
        loss_c = 1.0 * loss_c
        loss_s = 3.0 * loss_s
        loss = loss_c + loss_s + 50 * l_identity1 + l_identity2

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # skips the step if scaled grads contain inf/NaN
    scaler.update()

    if (i + 1) % LOG_INTERVAL == 0:
        history['iterations'].append(i + 1)
        history['train_loss'].append(loss.item())
        history['train_loss_c'].append(loss_c.item())
        history['train_loss_s'].append(loss_s.item())
        history['train_loss_id1'].append(l_identity1.item())
        history['train_loss_id2'].append(l_identity2.item())
        history['lr'].append(current_lr)

        print(f"Iter: {i+1}/{MAX_ITER} | Loss: {loss.item():.4f} | Content: {loss_c.item():.4f} | Style: {loss_s.item():.4f} | Id1: {l_identity1.item():.4f} | Id2: {l_identity2.item():.4f}")

    if (i + 1) % 10 == 0:
        torch.cuda.empty_cache()

    if (i + 1) % SAVE_INTERVAL == 0:
        val_loss = validate(val_loader, network, device)
        history['val_loss'].append(val_loss)
        is_best = val_loss < best_loss

        if is_best:
            best_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        save_checkpoint(i + 1, val_loss, best_loss, patience_counter, is_best)
        generate_sample(network, val_loader, device, i + 1)
        _dump_history()
        print(f"Checkpoint saved at {i+1} | Val Loss: {val_loss:.4f} | Best: {is_best} | BestLoss: {best_loss:.4f} | Patience: {patience_counter}/{PATIENCE}")

        if patience_counter >= PATIENCE:
            print(f"Early stopping triggered at iteration {i+1}")
            print(f"Best validation loss: {best_loss:.4f}")
            break

print("-" * 50)
print("Training completed!")
print(f"Final best loss: {best_loss:.4f}")

_dump_history()
print("Training history saved to training_history.pkl")
