In [None]:
# Mount Google Drive so the generator checkpoint stored under MyDrive is readable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install torch torchvision pillow flask pyngrok kagglehub

Collecting pyngrok
  Downloading pyngrok-7.2.8-py3-none-any.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nv

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from flask import Flask, request, render_template_string, send_from_directory
from werkzeug.utils import secure_filename
from torchvision.utils import save_image
from torchvision import transforms
from pyngrok import ngrok

In [None]:
# ========== 1. SETUP DIRECTORIES ==========
# uploads/: raw files posted to /inpaint; output/: generated results;
# static/: served by Flask at /static/, holds the page logo (ulab2.webp).
for _folder in ('uploads', 'output', 'static'):
    os.makedirs(_folder, exist_ok=True)

In [None]:
# ========== 2. GAN+UNET MODEL ==========
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv => BatchNorm => ReLU) stages.

    Padding of 1 keeps the spatial size unchanged. ``mid_channels`` is the
    width of the intermediate activation and falls back to ``out_channels``
    when it is falsy (None or 0).

    The layer order inside ``self.double_conv`` defines the state_dict keys
    (``double_conv.0.weight``, ...), which must match the saved checkpoint.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W), then a DoubleConv.

    The submodule is kept under the name ``maxpool_conv`` so state_dict keys
    stay compatible with the saved checkpoint.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, DoubleConv.

    With ``bilinear=True`` the upsampling is parameter-free and the DoubleConv
    uses ``in_channels // 2`` intermediate channels. Otherwise a stride-2
    ConvTranspose2d maps ``in_channels`` to ``in_channels // 2`` before the concat.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Pad so the upsampled map has the same H/W as the skip tensor x2
        # (sizes can differ by one for odd inputs). F.pad's list is
        # [left, right, top, bottom] for the last two dimensions.
        dh = x2.shape[2] - upsampled.shape[2]
        dw = x2.shape[3] - upsampled.shape[3]
        upsampled = F.pad(upsampled, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return self.conv(torch.cat([x2, upsampled], dim=1))

class OutConv(nn.Module):
    """1x1 convolution projecting feature channels to output channels (no activation)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)

# ==========  UNET ARCHITECTURE ==========
class UNet(nn.Module):
    """Four-level U-Net: DoubleConv stem, 4 Down stages, 4 Up stages, 1x1 head.

    Args:
        n_channels: channels of the input image.
        n_classes: channels produced by the final 1x1 conv.
        bilinear: use parameter-free bilinear upsampling in the decoder
            (which halves the channel widths of the deepest layers) instead
            of transposed convolutions.

    ``forward`` returns the raw output of the 1x1 conv (no activation). The Up
    stages pad to their skip tensors, so the spatial size follows the stem's.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        factor = 2 if bilinear else 1

        # Encoder (attribute names are part of the checkpoint's state_dict keys).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder outputs, shallowest first; the deepest one seeds the decoder
        # and the rest are consumed as skip connections in reverse order.
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(stage(skips[-1]))
        out = skips.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            out = stage(out, skips.pop())
        return self.outc(out)

# ==========  GENERATOR ==========
class UNetGenerator(nn.Module):
    """GAN generator that simply delegates to a ``UNet``.

    The wrapped network lives under ``self.unet``, so checkpoint keys are
    prefixed with ``unet.``.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.unet = UNet(n_channels, n_classes)

    def forward(self, x):
        out = self.unet(x)
        return out

# ==========  INITIALIZE MODEL ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_PATH = "/content/drive/MyDrive/gan_unet_checkpoints/generator_epoch_49.pth"

generator = UNetGenerator(3, 3).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensors and plain containers.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
generator.load_state_dict(state_dict)
generator.eval();  # trailing ';' suppresses the very long module repr

UNetGenerator(
  (unet): UNet(
    (inc): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (down1): Down(
      (maxpool_conv): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (double_conv): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), pa

In [None]:
# ========== 3. IMAGE PROCESSING ==========
# Resize to the fixed 256x256 input size, convert to a [0, 1] tensor, then map
# each channel to [-1, 1] via (x - 0.5) / 0.5 (the single mean/std broadcast
# over all channels).
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

def predict_image(image_path, gap_size=96):
    """Mask a centered square of an image and let the generator fill it in.

    The image is resized and normalized with ``transform``. A centered
    ``gap_size`` x ``gap_size`` square is zeroed, and the masked tensor is
    passed through ``generator``. The generator output is saved as a PNG
    under ``output/``.

    Args:
        image_path: path to an image file PIL can open.
        gap_size: side length in pixels (after resizing) of the centered
            square to mask. NOTE(review): 96 is presumably the gap size used
            in training — confirm before changing it.

    Returns:
        Path of the written PNG inside ``output/``.

    Raises:
        ValueError: if ``gap_size`` is negative or larger than the resized image.
    """
    # Context manager releases the file handle; convert() returns a new
    # in-memory image that remains valid after the file is closed.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Dims 2 and 3 of the batched tensor are height and width.
    height, width = img_tensor.shape[2], img_tensor.shape[3]
    if not 0 <= gap_size <= min(height, width):
        raise ValueError(f"gap_size must be in [0, {min(height, width)}], got {gap_size}")
    top = (height - gap_size) // 2
    left = (width - gap_size) // 2
    mask = torch.ones_like(img_tensor)
    mask[:, :, top:top + gap_size, left:left + gap_size] = 0

    with torch.no_grad():
        filled_img = generator(img_tensor * mask)

    # Always write PNG. The upload's own extension can be missing (e.g.
    # secure_filename reduces "图片.png" to "png") or not writable by PIL,
    # in which case save_image cannot infer an output format.
    stem = os.path.splitext(os.path.basename(image_path))[0] or "result"
    output_path = os.path.join('output', stem + '.png')
    # NOTE(review): normalize=True min-max rescales each result; if the
    # generator was trained on [-1, 1] targets, (x + 1) / 2 may be intended.
    save_image(filled_img, output_path, normalize=True)
    return output_path

In [None]:
# ========== 4. FLASK WEB APP ==========
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['OUTPUT_FOLDER'] = 'output'
# The app is exposed publicly through an ngrok tunnel, so cap request bodies;
# Flask rejects larger uploads with 413 Request Entity Too Large.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024


# Single-page UI (HTML/CSS/JS) served by the '/' route; it POSTs the chosen
# file to /inpaint and shows the returned image.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Image Inpainting</title>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 0;
      padding: 0;
      background: #C1E1C1;
    }
    .logo {
      text-align: center;
      padding: 20px;
    }
    .logo img {
      height: 80px;
    }
    .title {
      text-align: center;
      color: #333;
      margin-bottom: 30px;
    }
    .container {
      display: flex;
      justify-content: center;
      gap: 50px;
      padding: 20px;
    }
    .box {
      background: white;
      border-radius: 8px;
      padding: 20px;
      box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
      width: 400px;
      text-align: center;
    }
    .label-text {
      display: block;
      margin-bottom: 15px;
      font-weight: bold;
      color: #444;
    }
    .image-box {
      width: 100%;
      height: 300px;
      object-fit: contain;
      border: 1px solid #ddd;
      border-radius: 4px;
      background: #f9f9f9;
    }
    #upload {
      display: none;
    }
    #processBtn {
      display: block;
      margin: 20px auto;
      padding: 10px 25px;
      background: #4CAF50;
      color: white;
      border: none;
      border-radius: 4px;
      cursor: pointer;
      font-size: 16px;
    }
    #processBtn:hover {
      background: #45a049;
    }
    #loading {
      text-align: center;
      font-size: 18px;
      color: #666;
    }
  </style>
</head>
<body>
  <div class="logo">
    <img src="/static/ulab2.webp" alt="Logo">
  </div>
  <h1 class="title">Image Inpainting</h1>
  <div class="container">
    <div class="box">
      <label for="upload" class="label-text">Upload Image</label>
      <img id="uploadedImage" src="" alt="" class="image-box">
      <input type="file" id="upload" accept="image/*">
    </div>
    <div class="box">
      <label class="label-text">Result Image</label>
      <img id="resultImage" src="" alt="" class="image-box">
    </div>
  </div>
  <button id="processBtn">Process Image</button>
  <div id="loading" style="display: none;">Processing...</div>
  <script>
    document.addEventListener("DOMContentLoaded", () => {
      const uploadInput = document.getElementById("upload");
      const processBtn = document.getElementById("processBtn");
      const uploadedImage = document.getElementById("uploadedImage");
      const resultImage = document.getElementById("resultImage");
      const loading = document.getElementById("loading");

      // Object URL of the result currently shown. Each createObjectURL call
      // pins its blob in memory until revoked, so release the old one first.
      let resultUrl = null;
      const releaseResultUrl = () => {
        if (resultUrl) {
          URL.revokeObjectURL(resultUrl);
          resultUrl = null;
        }
      };

      // Preview uploaded image
      uploadInput.addEventListener("change", (e) => {
        const file = e.target.files[0];
        if (file) {
          const reader = new FileReader();
          reader.onload = (event) => {
            uploadedImage.src = event.target.result;
            resultImage.src = "";
            releaseResultUrl();
          };
          reader.readAsDataURL(file);
        }
      });

      // Process image
      processBtn.addEventListener("click", async () => {
        const file = uploadInput.files[0];
        if (!file) {
          alert("Please upload an image first!");
          return;
        }

        loading.style.display = "block";
        processBtn.disabled = true;

        const formData = new FormData();
        formData.append("file", file);

        try {
          const response = await fetch("/inpaint", {
            method: "POST",
            body: formData,
          });

          if (!response.ok) throw new Error("Failed to process image");

          const blob = await response.blob();
          releaseResultUrl();
          resultUrl = URL.createObjectURL(blob);
          resultImage.src = resultUrl;
        } catch (error) {
          console.error("Error:", error);
          alert("Error: " + error.message);
        } finally {
          loading.style.display = "none";
          processBtn.disabled = false;
        }
      });
    });
  </script>
</body>
</html>
"""

@app.route('/')
def home():
    """Render the single-page inpainting UI from HTML_TEMPLATE."""
    page = render_template_string(HTML_TEMPLATE)
    return page

@app.route('/inpaint', methods=['POST'])
def inpaint():
    """Inpaint an uploaded image and send back the generated file.

    Expects a multipart form with the image in the ``file`` field (as sent by
    the page's JavaScript). Responds 400 when no file is sent or when PIL
    cannot decode the upload.
    """
    import uuid
    from PIL import UnidentifiedImageError

    file = request.files.get('file')
    if file is None or not file.filename:
        return "No file uploaded", 400

    # Store under a random name. Concurrent requests with the same original
    # filename would otherwise overwrite each other's upload and result.
    # Also, secure_filename() can return '' (e.g. for '../..'), which would
    # make the path the upload directory itself.
    ext = os.path.splitext(secure_filename(file.filename))[1].lower()
    if ext not in {'.png', '.jpg', '.jpeg', '.bmp', '.webp'}:
        # Keep only extensions PIL can write back (the result file may reuse
        # the upload's name); PIL detects the real format from file content.
        ext = '.png'
    upload_path = os.path.join(app.config['UPLOAD_FOLDER'], uuid.uuid4().hex + ext)
    file.save(upload_path)

    try:
        output_path = predict_image(upload_path)
    except UnidentifiedImageError:
        return "Uploaded file is not a readable image", 400
    return send_from_directory(app.config['OUTPUT_FOLDER'], os.path.basename(output_path))

In [None]:
# ========== 5. UPLOAD YOUR LOGO ==========
from google.colab import files
from IPython.display import display, HTML

# files.upload() is Colab's supported upload widget; the classic-Jupyter
# IPython.notebook.kernel JavaScript API is not available in Colab.
display(HTML("<h3>Upload your logo (ulab2.webp):</h3>"))
uploaded_logo = files.upload()  # {filename: bytes}
if uploaded_logo:
    # Whatever the uploaded name, store it where the page requests it:
    # <img src="/static/ulab2.webp">, served from the static/ folder.
    logo_bytes = next(iter(uploaded_logo.values()))
    with open(os.path.join('static', 'ulab2.webp'), 'wb') as fh:
        fh.write(logo_bytes)
else:
    print("No logo uploaded; the page will show a broken logo image.")



In [None]:
# ========== 6. RUN THE WEB APP ==========
def run_app(port=5000):
    """Expose the Flask app through an ngrok tunnel and serve it (blocks).

    The ngrok authtoken comes from the NGROK_AUTHTOKEN environment variable.
    If that is unset, it is prompted for with getpass, so it never lives in
    the notebook. Never hard-code it here: notebooks and their outputs get
    shared, and any token that has been committed should be rotated in the
    ngrok dashboard.

    Args:
        port: local port for the Flask server and the tunnel (default 5000).
    """
    import getpass

    # pyngrok is already installed and imported at the top of the notebook,
    # so there is no need to pip-install it here.
    token = os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("ngrok authtoken: ")
    ngrok.set_auth_token(token)

    ngrok_tunnel = ngrok.connect(port)
    print(' * Public URL:', ngrok_tunnel.public_url)
    try:
        app.run(host='0.0.0.0', port=port)
    finally:
        # Close the public tunnel once the server stops (e.g. cell interrupted).
        ngrok.disconnect(ngrok_tunnel.public_url)

In [None]:
# Blocks this cell while the Flask server runs; interrupt the cell to stop it.
run_app()  # Start the web interface

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml
 * Public URL: https://6c16-34-142-219-111.ngrok-free.app
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.28.0.12:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m
INFO:werkzeug:127.0.0.1 - - [14/May/2025 16:56:52] "GET / HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [14/May/2025 16:56:53] "GET /static/ulab2.webp HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [14/May/2025 16:56:54] "[33mGET /favicon.ico HTTP/1.1[0m" 404 -
