<a href="https://colab.research.google.com/github/kodenshacho/sigma/blob/master/upscale_pk.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.onnx
import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet

# --- Layer definition (Flatten -> FC -> Reshape) ---
class IdentityFC(nn.Module):
    """Fully-connected layer over a whole feature map: flatten -> Linear -> reshape.

    Each sample's ``(C, H, W)`` feature map is flattened to a vector, multiplied
    by a square, bias-free weight matrix and reshaped back to ``(C, H, W)``.
    With the default identity initialisation the layer passes its input through
    unchanged, so it can be spliced into a pretrained network without altering
    the network's output.

    NOTE: the weight holds ``(C*H*W) ** 2`` values, so memory grows with the
    fourth power of the spatial size; only small feature maps are practical.

    Args:
        shape: per-sample feature-map shape ``(C, H, W)``.
        random_init: if True, initialise the weight with Kaiming-normal noise
            instead of the identity matrix.
    """

    def __init__(self, shape, random_init=False):
        super().__init__()
        self.shape = shape  # (C, H, W)
        flat_dim = shape[0] * shape[1] * shape[2]
        self.fc = nn.Linear(flat_dim, flat_dim, bias=False)

        if random_init:
            self.init_random()
        else:
            self.init_identity(flat_dim)

    def init_identity(self, dim=None):
        """Set the weight to the identity matrix (the layer becomes a no-op).

        Args:
            dim: kept for backward compatibility; when given it must equal the
                layer's flattened size.

        Raises:
            ValueError: if ``dim`` does not match the weight's size.
        """
        if dim is not None and tuple(self.fc.weight.shape) != (dim, dim):
            raise ValueError(
                f"dim={dim} does not match weight shape {tuple(self.fc.weight.shape)}"
            )
        # eye_ fills the existing parameter in place: no Python loop over `dim`
        # and no second dim x dim temporary buffer.
        nn.init.eye_(self.fc.weight)

    def init_random(self):
        """Re-initialise the weight with Kaiming-normal noise (leaky-ReLU gain, slope 0.01)."""
        nn.init.kaiming_normal_(self.fc.weight, a=0.01)

    def forward(self, x):
        """Apply the FC layer to an ``(N, C, H, W)`` tensor and return the same shape."""
        b, c, h, w = x.shape
        # flatten(1) also accepts non-contiguous inputs, where .view() would raise.
        x_fc = self.fc(x.flatten(1))
        return x_fc.view(b, c, h, w)

# --- Insert an FC layer into the Real-ESRGAN model ---
def insert_fc_into_pretrained(model: nn.Module, random_init=False, verbose=True,
                              input_size=(1200, 1600), in_channels=3):
    """Insert an ``IdentityFC`` after the smallest feature map of ``model.body``.

    A dummy image is pushed through ``model`` itself, so any preprocessing its
    ``forward`` performs before ``conv_first`` is respected. Forward hooks
    record the output shape of ``model.conv_first`` and of every child of
    ``model.body``; the pass is aborted as soon as the last body child has run.
    The FC layer is inserted right after the first position with the smallest
    spatial area. ``min_idx == -1`` means the ``conv_first`` output is the
    smallest, i.e. the layer goes in front of ``body[0]``.

    Args:
        model: module exposing ``conv_first`` and ``body``; ``model.body`` is
            replaced in place.
        random_init: forwarded to ``IdentityFC``.
        verbose: print the chosen position and feature-map size.
        input_size: ``(H, W)`` of the dummy probe image.
        in_channels: channel count of the dummy probe image.

    Returns:
        The same ``model`` object, with ``model.body`` rebuilt as an
        ``nn.Sequential`` that contains the inserted layer.

    Raises:
        RuntimeError: if the probe pass did not run ``conv_first`` and every
            ``body`` child.
    """
    blocks = list(model.body.children())
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dummy = torch.randn(1, in_channels, *input_size, device=device)

    # Slot 0 = output of conv_first (input of body[0]); slot k + 1 = output of body[k].
    shapes = {}
    last_slot = len(blocks)

    class _ProbeDone(Exception):
        """Raised from the last hook to skip the rest of the forward pass."""

    def _record(slot):
        def hook(_module, _inputs, output):
            # Keep only the shape: holding every full feature map costs a lot of memory.
            shapes[slot] = output.shape
            if slot == last_slot:
                raise _ProbeDone
        return hook

    handles = [model.conv_first.register_forward_hook(_record(0))]
    handles += [blk.register_forward_hook(_record(k + 1)) for k, blk in enumerate(blocks)]
    try:
        with torch.no_grad():
            model(dummy)
    except _ProbeDone:
        pass
    finally:
        for handle in handles:
            handle.remove()

    if len(shapes) != len(blocks) + 1:
        raise RuntimeError("probe forward pass did not run conv_first and every body block")

    # min() returns the first minimum, so ties keep the earliest slot. This matches
    # the original strict `<` comparison seeded with the conv_first output.
    best_slot = min(range(len(blocks) + 1), key=lambda s: shapes[s][2] * shapes[s][3])
    min_idx = best_slot - 1
    min_shape = shapes[best_slot]

    if verbose:
        print(f"üîç ÊúÄÂ∞èÁâπÂæ¥„Éû„ÉÉ„Éó‰ΩçÁΩÆ: model.body[{min_idx}]„ÄÅ„Çµ„Ç§„Ç∫: {min_shape}")

    fc_layer = IdentityFC(shape=min_shape[1:], random_init=random_init).to(device)
    model.body = nn.Sequential(*blocks[:best_slot], fc_layer, *blocks[best_slot:])
    return model

# --- Export to an ONNX file ---
def export_to_onnx(model, input_tensor, onnx_path="exported_model.onnx"):
    """Trace ``model`` with ``input_tensor`` and write the graph to ``onnx_path``.

    The model is put into eval mode first; this mutates the caller's model.
    The graph's input is named ``input`` and its output ``output``. Only axis 0
    (``batch``) of each is declared dynamic, and opset 11 is requested.

    Args:
        model: module to export.
        input_tensor: example input used for tracing.
        onnx_path: destination file path.
    """
    export_kwargs = {
        "input_names": ['input'],
        "output_names": ['output'],
        "dynamic_axes": {
            'input': {0: 'batch'},
            'output': {0: 'batch'},
        },
        "opset_version": 11,
    }
    model.eval()
    torch.onnx.export(model, input_tensor, onnx_path, **export_kwargs)
    print(f"‚úÖ ONNX„Å®„Åó„Å¶‰øùÂ≠ò„Åï„Çå„Åæ„Åó„Åü: {onnx_path}")

# --- Train the model (used when random_init=True) ---
def train_model(model, target_model, epochs=1, lr=1e-4, input_size=(1200, 1600)):
    """Fit ``model`` to reproduce ``target_model``'s output on one fixed random input.

    A single random 3-channel image is drawn once. The target output is
    computed once without gradients. ``model`` is then optimised with Adam on
    the MSE between its output and that target for ``epochs`` steps. All of
    ``model``'s parameters are optimised, not only an inserted FC layer.
    ``model`` is left in training mode.

    Args:
        model: module to train (updated in place).
        target_model: reference module; only run forward, never updated.
        epochs: number of optimisation steps.
        lr: Adam learning rate.
        input_size: ``(H, W)`` of the random training image.

    Returns:
        list[float]: the loss measured at each step, before that step's update.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    # Create the input on the model's device so GPU-resident models work too.
    # The parameter list is non-empty here, because Adam rejects an empty one.
    device = next(model.parameters()).device
    dummy_input = torch.randn(1, 3, *input_size, device=device)

    with torch.no_grad():
        target_output = target_model(dummy_input)

    losses = []
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(dummy_input)
        loss = loss_fn(output, target_output)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        print(f"üß™ Epoch {epoch+1}/{epochs}, Loss: {losses[-1]:.6f}")
    return losses

# --- Compare results in a GUI window (OpenCV) ---
def visualize_output(output1, output2):
    """Show two image tensors side by side in an OpenCV window (left: output1, right: output2).

    Each tensor is treated as an RGB image, either ``(3, H, W)`` or a batch
    ``(N, 3, H, W)``; for a batch only the first sample is shown. Values are
    clamped to ``[0, 1]``, scaled to 8-bit and converted to BGR for OpenCV.
    The call blocks until a key is pressed, then closes all OpenCV windows.

    NOTE(review): cv2.imshow needs a display; it is presumably unavailable in
    headless environments such as Colab — confirm for the target runtime.
    """
    def tensor_to_cv(img):
        # Pick the first sample explicitly. A plain .squeeze() fails for batch > 1
        # and would also drop any singleton H/W dimension.
        if img.dim() == 4:
            img = img[0]
        # detach() so tensors that still require grad can be converted to NumPy.
        img = img.detach().permute(1, 2, 0).clamp(0, 1).cpu().numpy()
        img = (img * 255).astype(np.uint8)
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    combined = np.hstack((tensor_to_cv(output1), tensor_to_cv(output2)))
    cv2.imshow('Â∑¶: Pretrained, Âè≥: FC‰ªò„Åç', combined)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

# --- Main function ---
def main():
    """Build an FC-augmented copy of the pretrained model, compare, save and display it.

    Steps:
    1. Load the pretrained RRDBNet checkpoint.
    2. Make a second copy with an ``IdentityFC`` inserted (random or identity init).
    3. Optionally fine-tune that copy against the original.
    4. Compare both models' outputs on a random input.
    5. Save the augmented model as ``.pth`` and ``.onnx``.
    6. Show both outputs side by side.
    """
    # Load the pretrained model. The checkpoint is read once and mapped to the CPU,
    # so a GPU-saved file also loads on CPU-only machines.
    model_path = 'pretrained/RealESRGAN_x1_fixed_1600x1200.pth'
    state_dict = torch.load(model_path, map_location='cpu')
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                    num_block=23, num_grow_ch=32, scale=1)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # Build the FC-augmented model (random or identity initialisation).
    # load_state_dict copies values into the parameters, so both models can
    # safely be loaded from the same state dict.
    # NOTE(review): IdentityFC has (C*H*W)**2 weights; at a 1200x1600 input the
    # smallest body feature map is presumably 64x300x400, which would make that
    # weight infeasibly large — confirm before running at full resolution.
    use_random_init = True  # <- when True, the FC model is also fine-tuned
    model_fc = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                       num_block=23, num_grow_ch=32, scale=1)
    model_fc.load_state_dict(state_dict, strict=True)
    model_fc = insert_fc_into_pretrained(model_fc, random_init=use_random_init, verbose=True)

    # Random input image.
    input_img = torch.randn(1, 3, 1200, 1600)

    # Fine-tune only when the FC layer starts from random weights.
    if use_random_init:
        print("‚öôÔ∏è FC„É¨„Ç§„É§„Éº„ÅÆ„É©„É≥„ÉÄ„É†ÂàùÊúüÂåñ„Å´ÂØæ„Åó„Å¶ÂæÆË™øÊï¥„ÇíË°å„ÅÑ„Åæ„Åô...")
        train_model(model_fc, model, epochs=3)
    # train_model leaves the model in training mode; switch back for inference.
    model_fc.eval()

    # Compute and compare both outputs.
    with torch.no_grad():
        out1 = model(input_img)
        out2 = model_fc(input_img)
        is_same = torch.allclose(out1, out2, atol=1e-6)
        print(f"‚úÖ Âá∫Âäõ‰∏ÄËá¥: {is_same}")

    # Save the weights as a .pth file.
    torch.save(model_fc.state_dict(), "modified_model_fc.pth")
    print("‚úÖ FC‰ªò„Åç„É¢„Éá„É´„Çí‰øùÂ≠ò„Åó„Åæ„Åó„Åü: modified_model_fc.pth")

    # Save in ONNX format.
    export_to_onnx(model_fc, input_img, onnx_path="modified_model_fc.onnx")

    # Show both output images side by side.
    visualize_output(out1, out2)

if __name__ == "__main__":
    main()

In [None]:
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from insert_fc import insert_fc_into_pretrained

def test_model_with_fc():
    """Check that inserting an identity-initialised FC layer leaves the output unchanged.

    The pretrained RRDBNet and an FC-augmented copy (``random_init=False``)
    are both run on the same random input. Whether their outputs match within
    ``atol=1e-6`` is printed; the function does not assert.
    """
    # Load the pretrained model. The checkpoint is read once and mapped to the CPU,
    # so a GPU-saved file also loads on CPU-only machines.
    model_path = 'pretrained/RealESRGAN_x1_fixed_1600x1200.pth'
    state_dict = torch.load(model_path, map_location='cpu')
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                    num_block=23, num_grow_ch=32, scale=1)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # Build a second copy from the same weights and insert the FC layer.
    model_fc = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                       num_block=23, num_grow_ch=32, scale=1)
    model_fc.load_state_dict(state_dict, strict=True)
    model_fc = insert_fc_into_pretrained(model_fc, random_init=False, verbose=True)
    model_fc.eval()

    # Random test image.
    input_img = torch.randn(1, 3, 1200, 1600)

    # Compare the outputs.
    with torch.no_grad():
        out1 = model(input_img)
        out2 = model_fc(input_img)
        is_same = torch.allclose(out1, out2, atol=1e-6)

    print(f"‚úÖ Âá∫Âäõ„Åå‰∏ÄËá¥„Åô„Çã„Åã: {is_same}")

if __name__ == "__main__":
    test_model_with_fc()