# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import tifffile as tiff      # pip install tifffile
import numpy as np

class FlameSeqDataset(Dataset):
    """
    Sliding-window dataset over thermal TIFF frames.

    Each item is ``(seq, label)``: ``seq`` is a float32 tensor of shape
    (seq_len, 1, H, W) holding scaled temperatures (the files are read from a
    "Celsius TIFF" folder), and ``label`` is a scalar float32 tensor that is
    1 for 'Fire' and 0 for 'No Fire'.

    Frames are looked up in ``<root>/<class>/Thermal/Celsius TIFF`` and
    ordered by sorted path; every run of ``seq_len`` consecutive frames within
    one class folder becomes a sample. A missing class folder contributes no
    samples.

    Args:
        root: Dataset root directory (str or Path).
        seq_len: Number of consecutive frames per sample; must be >= 1.
        downscale: Spatial scale in (0, 1]. Frames are subsampled with an
            integer stride of ``int(1 / downscale)`` (0.5 halves width and
            height); values between 0.5 and 1 therefore leave the size as is.
        transform: Optional callable applied to the stacked sequence tensor.
        temp_min: Raw value mapped to 0 by the scaling.
        temp_max: Raw value mapped to 1 by the scaling.

    Raises:
        ValueError: If ``seq_len``, ``downscale`` or the temperature range is
            invalid.
    """

    # Class folder name -> binary label.
    _CLASSES = (("Fire", 1), ("No Fire", 0))
    # Accepted file suffixes, compared case-insensitively.
    _TIFF_SUFFIXES = (".tif", ".tiff")

    def __init__(self, root, seq_len=5, downscale=1.0, transform=None,
                 temp_min=0.0, temp_max=1000.0):
        # Fail at construction time instead of deep inside a DataLoader worker.
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if not 0 < downscale <= 1:
            raise ValueError(f"downscale must be in (0, 1], got {downscale}")
        if temp_max <= temp_min:
            raise ValueError(
                f"temp_max ({temp_max}) must be greater than temp_min ({temp_min})"
            )

        self.seq_len = seq_len
        self.transform = transform
        self.downscale = downscale     # e.g. 0.5 to halve width and height
        self.temp_min = float(temp_min)
        self.temp_max = float(temp_max)
        self.samples = []

        root = Path(root)
        for cls_name, label in self._CLASSES:
            tiff_dir = root / cls_name / "Thermal" / "Celsius TIFF"
            if not tiff_dir.is_dir():
                continue
            # Suffix match is case-insensitive so ".TIFF", ".tiff" and ".tif"
            # are all picked up regardless of the filesystem.
            paths = sorted(
                p for p in tiff_dir.iterdir()
                if p.suffix.lower() in self._TIFF_SUFFIXES
            )
            # Sliding window; yields nothing when there are fewer than
            # seq_len frames.
            for start in range(len(paths) - seq_len + 1):
                self.samples.append((paths[start:start + seq_len], label))

    def _to_tensor(self, arr):
        """Subsample and scale one (H, W) frame; returns a (1, H', W') tensor."""
        arr = np.asarray(arr, dtype=np.float32)
        if self.downscale != 1.0:
            # Cheap strided shrink to cut memory; the max() keeps the step
            # valid even if the attribute was changed after construction.
            step = max(1, int(1 / self.downscale))
            arr = arr[::step, ::step]
        frame = torch.from_numpy(arr)[None]    # add channel axis -> (1, H, W)
        # Linear scaling of [temp_min, temp_max] onto [0, 1]; values outside
        # the range are not clipped.
        return (frame - self.temp_min) / (self.temp_max - self.temp_min)

    def _read_tiff(self, path):
        """Load one TIFF frame from disk and return it as a scaled tensor."""
        arr = tiff.imread(str(path)).astype(np.float32)     # shape (H, W)
        return self._to_tensor(arr)

    def __len__(self):
        """Number of sliding-window samples across both classes."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(seq, label)`` for sample ``idx``."""
        paths, label = self.samples[idx]
        seq = torch.stack([self._read_tiff(p) for p in paths])  # (seq_len, 1, H, W)
        if self.transform:
            seq = self.transform(seq)
        return seq, torch.tensor(label, dtype=torch.float32)

# -----------------------------------------------------------------------------
# CNN encoder — keep it tiny so LSTM sees small vectors.
class CNNEncoder(nn.Module):
    """Small convolutional encoder mapping one frame to a 64-dim vector.

    Three stride-2 3x3 convolutions (1 -> 16 -> 32 -> 64 channels), each
    followed by ReLU, then global average pooling to a 1x1 map.
    """

    def __init__(self):
        super().__init__()
        channel_plan = ((1, 16), (16, 32), (32, 64))
        stages = []
        for c_in, c_out in channel_plan:
            stages.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            stages.append(nn.ReLU())
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Encode a batch of frames: (B, 1, H, W) -> (B, 64)."""
        pooled = self.conv(x)                # (B, 64, 1, 1)
        batch = pooled.size(0)
        return pooled.view(batch, -1)        # drop the 1x1 spatial dims

# -----------------------------------------------------------------------------
class CNNLSTM(nn.Module):
    """Per-frame CNN encoder followed by an LSTM and a single-logit head.

    Args:
        hidden: Hidden size of the LSTM.

    The LSTM input size is fixed at 64 to match the feature width produced
    by ``CNNEncoder``.
    """

    def __init__(self, hidden=128):
        super().__init__()
        self.encoder = CNNEncoder()
        self.lstm = nn.LSTM(input_size=64, hidden_size=hidden,
                            batch_first=True)
        self.fc = nn.Linear(hidden, 1)

    def forward(self, x):                    # (B, T, 1, H, W)
        """Return one logit per sequence, shape (B,).

        Raises:
            ValueError: If ``x`` is not 5-dimensional (B, T, C, H, W).
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transposing transform, are accepted; it is a no-copy view whenever
        # the layout allows.
        frames = x.reshape(B * T, C, H, W)
        feats = self.encoder(frames)         # (B*T, 64)
        feats = feats.reshape(B, T, -1)      # (B, T, 64)
        lstm_out, _ = self.lstm(feats)
        logits = self.fc(lstm_out[:, -1])    # last timestep -> (B, 1)
        return logits.squeeze(1)

# -----------------------------------------------------------------------------
def main():
    """Train CNNLSTM on the thermal sequences for 5 epochs, printing the mean loss."""
    root = "CVSubset/FLAME 3 CV Dataset (Sycan Marsh)"
    dataset = FlameSeqDataset(root, seq_len=5)
    if len(dataset) == 0:
        # A shuffling DataLoader cannot be built over an empty dataset, and
        # the usual cause is simply a wrong root path.
        raise SystemExit(f"no TIFF sequences found under {root!r}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pinned host memory only helps (and only avoids a warning) when batches
    # are actually copied to a GPU.
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        num_workers=4, pin_memory=device.type == "cuda")

    model = CNNLSTM().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)

    model.train()
    for epoch in range(5):
        loss_sum = 0.0
        n_seen = 0
        for seq, label in loader:
            seq = seq.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            optimiser.zero_grad()
            logit = model(seq)
            loss = criterion(logit, label)
            loss.backward()
            optimiser.step()
            # Weight by batch size so a short final batch does not skew the mean.
            batch = label.size(0)
            loss_sum += loss.item() * batch
            n_seen += batch
        # Report the epoch mean rather than whatever the last batch produced.
        print(f"epoch {epoch} loss {loss_sum / n_seen:.4f}")


if __name__ == "__main__":
    main()
