In [14]:
import base64
import io
import os
import tempfile
import warnings

import cv2
import numpy as np
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template
from PIL import Image
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torchvision import models

# ─── CONFIG ─────────────────────────────────────────────
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Checkpoint file that is read once at start-up.
CHECKPOINT_PATH = "models/resnet_transformer.pth"

# Length of the bin (K) axis of the model input.
NUM_BINS = 5

# Height and width, in pixels, that frames are resized to before inference.
IMG_H = 180
IMG_W = 240

# ─── MODEL ──────────────────────────────────────────────
class ResNetTransformer(nn.Module):
    """ResNet-50 per-bin feature extractor followed by a Transformer encoder.

    Every one of the K bins of a sample goes through a ResNet-50 whose first
    convolution accepts 2 input channels and whose final fully connected
    layer is replaced by an identity, which yields one 2048-d feature per
    bin. A learned positional embedding is added to the bin features, the
    sequence is passed through a 2-layer Transformer encoder, the encoder
    outputs are averaged over the bins, and a linear layer maps the result
    to class logits.

    Args:
        num_classes: Size of the output logit vector.
        num_bins: Number of bins (K) per sample; fixes the length of the
            learned positional embedding.
    """

    def __init__(self, num_classes, num_bins):
        super().__init__()
        # NOTE(review): ImageNet weights are requested here even when a
        # checkpoint is loaded on top afterwards — confirm whether that is
        # needed at inference time.
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Replace the stock first convolution with a freshly initialised
        # 2-channel one (same 64 filters, 7x7 kernel, stride 2, padding 3).
        resnet.conv1 = nn.Conv2d(
            2, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # Drop the ImageNet classification head so the backbone returns the
        # pooled 2048-d feature vector.
        resnet.fc = nn.Identity()
        self.backbone = resnet
        d_model = 2048
        # One learned offset per bin position, broadcast over the batch.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_bins, d_model))
        enc_layer = TransformerEncoderLayer(
            d_model, nhead=8, dim_feedforward=2048, dropout=0.1
        )
        self.transformer = TransformerEncoder(enc_layer, num_layers=2)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of binned inputs.

        Args:
            x: Tensor of shape [B, K, C, H, W], where K must equal the
                ``num_bins`` the model was built with.

        Returns:
            Tensor of shape [B, num_classes] holding unnormalised logits.

        Raises:
            ValueError: If ``x`` is not 5-dimensional or if K differs from
                the number of positional embeddings.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape [B,K,C,H,W], got {tuple(x.shape)}"
            )
        B, K, C, H, W = x.shape
        num_bins = self.pos_embed.shape[1]
        if K != num_bins:
            # Without this check K == 1 would silently broadcast against the
            # positional embedding and return plausible-looking logits.
            raise ValueError(
                f"expected {num_bins} bins along dim 1, got {K}"
            )
        # reshape (not view) so inputs whose batch/bin axes cannot be merged
        # without a copy, e.g. transposed tensors, are accepted as well.
        x = x.reshape(B * K, C, H, W)
        f = self.backbone(x)                        # [B*K, D]
        f = f.reshape(B, K, -1) + self.pos_embed    # [B, K, D]
        # The encoder is built with the default sequence-first layout.
        t = f.permute(1, 0, 2)                      # [K, B, D]
        out = self.transformer(t).mean(0)           # [B, D]
        return self.classifier(out)                 # [B, num_classes]

# ─── LOAD MODEL ──────────────────────────────────────────
# Read the checkpoint onto DEVICE. weights_only=True limits unpickling to
# tensors and plain Python containers.
ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)

# The file is either a dict wrapping the weights under "model_state_dict"
# or the bare state dict itself.
state_dict = ckpt.get("model_state_dict", ckpt)

# When the checkpoint carries no label map, size the fallback from the
# classifier weights so it matches the saved head; keep the historical
# default of 7 classes when that tensor is absent.
_cls_weight = state_dict.get("classifier.weight")
_num_classes = _cls_weight.shape[0] if _cls_weight is not None else 7
# NOTE(review): the request handler indexes label_map with integer class
# indices — confirm that saved label maps are keyed index -> name.
label_map = ckpt.get(
    "label_map", {i: f"class_{i}" for i in range(_num_classes)}
)

model = ResNetTransformer(len(label_map), NUM_BINS).to(DEVICE)

# strict=False tolerates key mismatches, which would otherwise leave some
# parameters at their initial values without any trace; report them.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        f"checkpoint {CHECKPOINT_PATH!r} does not match the model exactly: "
        f"missing keys {list(_load_result.missing_keys)}, "
        f"unexpected keys {list(_load_result.unexpected_keys)}"
    )

# Inference only: switch dropout and batch-norm layers to evaluation mode.
model.eval()

# ─── HELPERS ─────────────────────────────────────────────
def frame_to_voxel(img_pil):
    """Turn a single PIL image into a [1, K, 2, H, W] float32 model input.

    The image is resized to (IMG_W, IMG_H), min-max scaled to roughly
    [0, 1], copied into both input channels and repeated NUM_BINS times
    along the bin axis, so every bin holds the same frame.

    Args:
        img_pil: PIL image. Multi-band images (e.g. RGB) are converted to
            8-bit grayscale first; single-band images are used as they are.

    Returns:
        Tensor on DEVICE of shape [1, NUM_BINS, 2, IMG_H, IMG_W].
    """
    # A multi-band image would give an [H, W, bands] array and break the
    # [2, H, W] channel layout built below.
    if len(img_pil.getbands()) > 1:
        img_pil = img_pil.convert("L")

    arr = np.array(img_pil.resize((IMG_W, IMG_H))).astype(np.float32)
    lo, hi = arr.min(), arr.max()
    # The epsilon keeps a constant image from dividing by zero; such an
    # image maps to all zeros.
    arr = (arr - lo) / (hi - lo + 1e-6)
    vox = np.stack([arr, arr], axis=0)[None]      # [1,2,H,W]
    vox = np.repeat(vox, NUM_BINS, axis=0)        # [K,2,H,W]
    return torch.from_numpy(vox).unsqueeze(0).to(DEVICE)  # [1,K,2,H,W]

def extract_frame(video_path, timestamp=None):
    """Read one frame from a video file and return it as a grayscale image.

    Args:
        video_path: Path of the video file to open with OpenCV.
        timestamp: Position in seconds of the frame to extract. ``None``
            selects the middle frame of the clip. Positions outside the
            clip are clamped to its first or last frame.

    Returns:
        A PIL image built from the grayscale version of the selected frame.

    Raises:
        RuntimeError: If the video cannot be opened or the frame cannot be
            read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError("cannot open video")
        # Fall back to 25 fps when the reported frame rate is 0.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if timestamp is None:
            # default: middle of clip
            frame_no = frame_count // 2
        else:
            frame_no = int(timestamp * fps)
        # Clamp in both branches: a zero or negative reported frame count
        # must not produce a negative seek position.
        last_frame = max(frame_count - 1, 0)
        frame_no = max(0, min(frame_no, last_frame))
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
        ret, frame = cap.read()
    finally:
        # Release the capture handle on every path, including errors.
        cap.release()
    if not ret:
        raise RuntimeError(f"couldn't read frame {frame_no}")
    # convert BGR → grayscale PIL
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    return Image.fromarray(gray)

# ─── FLASK APP ───────────────────────────────────────────
# WSGI application object; the route decorators below register on it.
app = Flask(import_name=__name__)

@app.route("/")
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page

@app.route("/predict_video", methods=["POST"])
def predict_video():
    """
    Classify one frame of an uploaded video.

    Expects a multipart/form-data:
      - 'video' : the uploaded file
      - 'time'  : optional float seconds at which to extract frame; a value
                  that does not parse as a float is treated as absent, so
                  the middle frame is used

    Returns JSON with the predicted label ('emotion'), its probability
    ('confidence') and the probability of every label ('all_probs').
    Responds with HTTP 400 and an 'error' message when the file is missing,
    'time' is not finite, or the video cannot be opened or read.
    """
    vid = request.files.get('video')
    if vid is None:
        return jsonify(error="no video file"), 400

    ts = request.form.get('time', type=float, default=None)
    # "nan" and "inf" parse as floats but cannot be turned into a frame
    # index, so reject them up front.
    if ts is not None and not np.isfinite(ts):
        return jsonify(error="time must be a finite number"), 400

    # Keep the upload's extension on the temp file; the filename may be
    # missing entirely.
    suffix = os.path.splitext(vid.filename or "")[1]
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = tmp.name

    try:
        # Write through the already-open handle: reopening the path while
        # it is still open fails on Windows. The handle is closed before
        # the video is read back.
        with tmp:
            vid.save(tmp)

        try:
            frame = extract_frame(tmp_path, timestamp=ts)
        except RuntimeError as exc:
            # An unreadable upload is a client error, not a server failure.
            return jsonify(error=str(exc)), 400

        voxel = frame_to_voxel(frame)                 # [1,K,2,H,W]
        with torch.no_grad():
            logits = model(voxel)
            probs  = torch.softmax(logits, 1)[0].cpu().numpy()
            idx    = int(probs.argmax())
    finally:
        # Always remove the temp file, including on early returns and errors.
        os.unlink(tmp_path)

    return jsonify({
        "emotion":     label_map[idx],
        "confidence":  float(probs[idx]),
        "all_probs":   { label_map[i]: float(p) for i,p in enumerate(probs) }
    })

# Start the Flask server on port 5870 when this module is run as the
# main program.
if __name__ == "__main__":
    app.run(port=5870)


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5870
Press CTRL+C to quit
127.0.0.1 - - [20/Apr/2025 07:18:11] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Apr/2025 07:18:11] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [20/Apr/2025 07:18:22] "POST /predict_video HTTP/1.1" 200 -
