In [None]:
from pathlib import Path
import os
import getpass
import shutil
import subprocess

# Directory that is validated as the repo root and used as the clone target.
REPO_DIR = Path('/content/AnomalyDetection')

# Clone source; the ANOMALY_REPO_URL environment variable overrides the default.
REPO_URL = os.getenv('ANOMALY_REPO_URL', 'https://github.com/kh87joo2/AnomalyDetection.git')

# Opt-in switch: only the exact value '1' in ANOMALY_FORCE_RECLONE allows an
# existing, non-valid REPO_DIR to be removed before cloning.
FORCE_RECLONE = os.getenv('ANOMALY_FORCE_RECLONE', '0') == '1'


def is_repo_root(path: Path) -> bool:
    """Tell whether ``path`` contains every entry expected at the repo root.

    Args:
        path: Directory to inspect.

    Returns:
        True only if ``requirements.txt``, ``configs``, ``trainers`` and both
        SSL config files under ``configs`` exist below ``path``. Only existence
        is tested, not whether an entry is a file or a directory.
    """
    expected_entries = (
        ('requirements.txt',),
        ('configs',),
        ('trainers',),
        ('configs', 'patchtst_ssl.yaml'),
        ('configs', 'swinmae_ssl.yaml'),
    )
    for parts in expected_entries:
        # Stop at the first missing entry; later ones are not checked.
        if not path.joinpath(*parts).exists():
            return False
    return True


def clone_repo(repo_url: str, token: str = '') -> subprocess.CompletedProcess:
    """Clone ``repo_url`` into ``REPO_DIR`` with ``git clone``.

    Args:
        repo_url: Repository URL to clone.
        token: Optional access token. It is only used when ``repo_url`` starts
            with ``https://``; in that case it is placed in the userinfo part
            of the URL (percent-encoded) for the clone command.

    Returns:
        A ``CompletedProcess`` with text ``stdout``/``stderr`` and the clone's
        return code. When a token was used, ``args``, ``stdout`` and ``stderr``
        have the token replaced by ``***`` so callers can print them safely.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote

    use_token = bool(token) and repo_url.startswith('https://')
    if not use_token:
        return subprocess.run(['git', 'clone', repo_url, str(REPO_DIR)], text=True, capture_output=True)

    # Percent-encode so characters such as '@' or '/' cannot break the URL.
    encoded_token = quote(token, safe='')
    clone_url = repo_url.replace('https://', f'https://{encoded_token}@', 1)
    result = subprocess.run(['git', 'clone', clone_url, str(REPO_DIR)], text=True, capture_output=True)

    if result.returncode == 0:
        # git records the clone URL as the 'origin' remote; point it back at the
        # token-free URL so the secret is not left in .git/config. Best effort:
        # a failure here does not change the reported clone result.
        subprocess.run(
            ['git', '-C', str(REPO_DIR), 'remote', 'set-url', 'origin', repo_url],
            text=True,
            capture_output=True,
        )

    def _scrub(text):
        """Mask the token (raw and encoded forms) in captured text."""
        if not text:
            return text
        return text.replace(encoded_token, '***').replace(token, '***')

    return subprocess.CompletedProcess(
        args=[_scrub(arg) for arg in result.args],
        returncode=result.returncode,
        stdout=_scrub(result.stdout),
        stderr=_scrub(result.stderr),
    )


# Step 1: decide what to do with whatever currently sits at REPO_DIR.
if is_repo_root(REPO_DIR):
    print(f'[info] using existing repo: {REPO_DIR}')
elif REPO_DIR.exists():
    # Something is there but it does not look like the repo; only delete it
    # when the user explicitly opted in.
    if FORCE_RECLONE:
        print(f'[warn] removing existing directory because ANOMALY_FORCE_RECLONE=1: {REPO_DIR}')
        shutil.rmtree(REPO_DIR)
    else:
        raise RuntimeError(
            f'{REPO_DIR} exists but is not a valid repo root.\n'
            'Set ANOMALY_FORCE_RECLONE=1 to allow removal and reclone, or fix the directory manually.'
        )

# Step 2: clone when no valid repo root is present.
if not is_repo_root(REPO_DIR):
    print(f'[info] cloning repo: {REPO_URL}')
    token = os.environ.get('ANOMALY_GH_TOKEN', '').strip()
    result = clone_repo(REPO_URL, token=token)

    # An anonymous clone failed: offer one retry with an interactively
    # entered token (an empty answer skips the retry).
    if result.returncode != 0 and not token:
        print('[warn] public clone failed. If repo is private, enter a GitHub PAT.')
        token = getpass.getpass('GitHub PAT (private repo only): ').strip()
        if token:
            result = clone_repo(REPO_URL, token=token)

    if result.returncode != 0:
        # Whitespace-only stderr yields no lines after strip(); guard against
        # indexing an empty list so the RuntimeError below is what surfaces.
        stderr_lines = (result.stderr or '').strip().splitlines()
        if stderr_lines:
            last_line = stderr_lines[-1]
            if token:
                # Never echo the secret, even if git included the URL in its message.
                last_line = last_line.replace(token, '***')
            print('[git]', last_line)
        raise RuntimeError('git clone failed. Check repo URL, network, and token permissions.')

# Step 3: the clone (or pre-existing directory) must now pass validation.
if not is_repo_root(REPO_DIR):
    raise FileNotFoundError('Repo root validation failed after clone.')

# Step 4: make the repo the working directory; the relative paths printed
# below are resolved against it.
os.chdir(REPO_DIR)
print('cwd:', Path.cwd())
print('requirements.txt exists:', Path('requirements.txt').exists())
print('configs exists:', Path('configs').exists())
print('trainers exists:', Path('trainers').exists())


In [None]:
from pathlib import Path
import subprocess
import sys
import shlex


def run(cmd):
    """Run ``cmd`` as a subprocess and echo its combined output line by line.

    The command is printed first (shell-quoted, prefixed with ``$``). The
    child's stderr is merged into stdout and each line is printed as it
    arrives.

    Args:
        cmd: Iterable of command arguments; every item is converted with ``str``.

    Raises:
        RuntimeError: If the process exits with a non-zero status.
    """
    cmd = [str(x) for x in cmd]
    print('$', ' '.join(shlex.quote(x) for x in cmd))

    # The context manager closes the pipe and reaps the child on every exit
    # path, so an interrupted cell does not leak a file descriptor.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors='replace',  # undecodable bytes must not abort the stream
        bufsize=1,  # line-buffered so output appears as it is produced
    ) as proc:
        assert proc.stdout is not None
        try:
            for line in proc.stdout:
                print(line, end='')
        except BaseException:
            # Includes KeyboardInterrupt: stop the child instead of leaving it
            # running (and blocking the implicit wait on context exit).
            proc.kill()
            raise
        code = proc.wait()

    if code != 0:
        raise RuntimeError(f"Command failed ({code}): {' '.join(cmd)}")


# requirements.txt is looked up relative to the current working directory.
req = Path('requirements.txt')
if not req.exists():
    raise FileNotFoundError(
        f"requirements.txt not found in cwd={Path.cwd()}. Run bootstrap cell first or fix repo path."
    )

# Both installs go through the running interpreter's pip: first upgrade pip
# itself, then install the requirements file.
_pip_install = [sys.executable, '-m', 'pip', 'install']
for _extra_args in (['-U', 'pip'], ['-r', str(req)]):
    run(_pip_install + _extra_args)


In [None]:
import torch

# Report the installed torch version and whether CUDA is usable.
print('torch:', torch.__version__)
cuda_ok = torch.cuda.is_available()
print('cuda_available:', cuda_ok)
if cuda_ok:
    # Device details are only queried when CUDA reports as available.
    print('cuda_device_count:', torch.cuda.device_count())
    print('cuda_device_0:', torch.cuda.get_device_name(0))


In [None]:
import subprocess
import sys

# Launch the training module with the current interpreter, using a config
# path relative to the working directory.
cmd = [
    sys.executable,
    '-m', 'trainers.train_swinmae_ssl',
    '--config', 'configs/swinmae_ssl.yaml',
]
print('$', *cmd)

# Output is not captured; the child inherits this process's stdout/stderr.
result = subprocess.run(cmd)
if result.returncode:
    raise RuntimeError(f"Training failed with exit code {result.returncode}")


In [None]:
from pathlib import Path

# Confirm that the checkpoint file exists (path is relative to the working
# directory) and report its size.
checkpoint_path = Path('checkpoints/swinmae_ssl.pt')
checkpoint_found = checkpoint_path.exists()
print('checkpoint_exists:', checkpoint_found, checkpoint_path)
assert checkpoint_found, f'Missing checkpoint: {checkpoint_path}'
print('checkpoint_size_bytes:', checkpoint_path.stat().st_size)
