In [None]:
from pathlib import Path
import os
import getpass
import shutil
import subprocess

# Clone destination (fixed Colab path) plus env-overridable source URL and reclone switch.
REPO_DIR = Path('/content/AnomalyDetection')
REPO_URL = os.environ.get('ANOMALY_REPO_URL', 'https://github.com/kh87joo2/AnomalyDetection.git')
# Only the exact string '1' enables destructive recloning; unset or anything else means False.
FORCE_RECLONE = os.environ.get('ANOMALY_FORCE_RECLONE') == '1'


def is_repo_root(path: Path) -> bool:
    """Return True when ``path`` looks like a checkout of this project.

    The directory qualifies only if all of these exist under it:
    ``requirements.txt``, ``configs``, ``trainers``,
    ``configs/patchtst_ssl.yaml`` and ``configs/swinmae_ssl.yaml``.
    Checking stops at the first missing entry.
    """
    configs = path / 'configs'
    markers = (
        path / 'requirements.txt',
        configs,
        path / 'trainers',
        configs / 'patchtst_ssl.yaml',
        configs / 'swinmae_ssl.yaml',
    )
    for marker in markers:
        if not marker.exists():
            return False
    return True


def clone_repo(repo_url: str, token: str = '') -> subprocess.CompletedProcess:
    """Run ``git clone repo_url REPO_DIR`` and return the completed process.

    If ``token`` is non-empty and ``repo_url`` is HTTPS, the token is embedded
    as URL userinfo for the clone. git records the clone URL as the ``origin``
    remote in ``.git/config``. After a successful authenticated clone, origin is
    therefore reset to the token-free ``repo_url`` so the PAT is not left on disk.
    Any occurrence of the token in the captured stdout/stderr is replaced with
    ``***`` before returning, so callers can print it safely.

    Args:
        repo_url: Clone URL without credentials.
        token: Optional GitHub PAT; ignored for non-HTTPS URLs.

    Returns:
        The ``CompletedProcess`` of ``git clone`` (text mode, output captured).
        Callers inspect ``returncode``.
    """
    clone_url = repo_url
    if token and repo_url.startswith('https://'):
        clone_url = repo_url.replace('https://', f'https://{token}@', 1)
    result = subprocess.run(['git', 'clone', clone_url, str(REPO_DIR)], text=True, capture_output=True)

    if clone_url != repo_url:
        if result.returncode == 0:
            # Scrub the credential-bearing URL that git stored for the origin remote.
            scrub = subprocess.run(
                ['git', '-C', str(REPO_DIR), 'remote', 'set-url', 'origin', repo_url],
                text=True,
                capture_output=True,
            )
            if scrub.returncode != 0:
                print('[warn] could not remove the token from the origin remote URL in .git/config')
        # Error messages may echo the URL; never hand the raw token back to callers.
        if result.stdout:
            result.stdout = result.stdout.replace(token, '***')
        if result.stderr:
            result.stderr = result.stderr.replace(token, '***')
    return result


# Ensure REPO_DIR holds a valid checkout: reuse it, optionally wipe a broken
# directory (ANOMALY_FORCE_RECLONE=1), or clone it (public first, then a PAT
# from ANOMALY_GH_TOKEN or an interactive prompt). Then chdir into it.
if is_repo_root(REPO_DIR):
    print(f'[info] using existing repo: {REPO_DIR}')
elif REPO_DIR.exists():
    if FORCE_RECLONE:
        print(f'[warn] removing existing directory because ANOMALY_FORCE_RECLONE=1: {REPO_DIR}')
        shutil.rmtree(REPO_DIR)
    else:
        raise RuntimeError(
            f'{REPO_DIR} exists but is not a valid repo root.\n'
            'Set ANOMALY_FORCE_RECLONE=1 to allow removal and reclone, or fix the directory manually.'
        )

if not is_repo_root(REPO_DIR):
    print(f'[info] cloning repo: {REPO_URL}')
    token = os.environ.get('ANOMALY_GH_TOKEN', '').strip()
    result = clone_repo(REPO_URL, token=token)

    # Only prompt when no token was supplied; a failing env token is reported as-is.
    if result.returncode != 0 and not token:
        print('[warn] public clone failed. If repo is private, enter a GitHub PAT.')
        token = getpass.getpass('GitHub PAT (private repo only): ').strip()
        if token:
            result = clone_repo(REPO_URL, token=token)

    if result.returncode != 0:
        # stderr may be whitespace-only, in which case there is no last line to show.
        stderr_lines = (result.stderr or '').strip().splitlines()
        if stderr_lines:
            last_line = stderr_lines[-1]
            if token:
                # git error lines can echo the credential-bearing clone URL.
                last_line = last_line.replace(token, '***')
            print('[git]', last_line)
        raise RuntimeError('git clone failed. Check repo URL, network, and token permissions.')

if not is_repo_root(REPO_DIR):
    raise FileNotFoundError('Repo root validation failed after clone.')

# Later cells use paths relative to the repo root.
os.chdir(REPO_DIR)
print('cwd:', Path.cwd())
print('requirements.txt exists:', Path('requirements.txt').exists())
print('configs exists:', Path('configs').exists())
print('trainers exists:', Path('trainers').exists())


In [None]:
# KAGGLE_DATA_DOWNLOAD_SWINMAE
from pathlib import Path
import json
import os
import re
import subprocess
import sys
import zipfile

import numpy as np
import pandas as pd
import yaml

REPO_DIR = Path('/content/AnomalyDetection')
DATASET = 'mohdsufianbinothman/triaxial-bearing-dataset'
# RAW_DIR receives the Kaggle download and extracted files; OUT_DIR receives the
# standardized x/y/z file that the generated training config points at.
RAW_DIR = REPO_DIR / 'data' / 'raw' / 'vibration'
OUT_DIR = REPO_DIR / 'data' / 'vib'

for _dir in (RAW_DIR, OUT_DIR):
    _dir.mkdir(parents=True, exist_ok=True)


def run(cmd):
    """Echo ``cmd`` as a ``$ ...`` line, then execute it and wait for it to finish.

    Every argument is converted with ``str()`` first, so ``Path`` objects are accepted.
    Output is not captured.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    argv = list(map(str, cmd))
    print('$ ' + ' '.join(argv))
    subprocess.run(argv, check=True)


run([sys.executable, '-m', 'pip', 'install', '-q', 'kaggle', 'pandas', 'numpy', 'pyyaml'])

# Credentials file for the Kaggle CLI. If it is missing, create it from
# KAGGLE_USERNAME/KAGGLE_KEY, or fall back to an interactive Colab upload.
kaggle_dir = Path('/root/.kaggle')
kaggle_dir.mkdir(parents=True, exist_ok=True)
kaggle_json = kaggle_dir / 'kaggle.json'

if not kaggle_json.exists():
    user = os.environ.get('KAGGLE_USERNAME', '').strip()
    key = os.environ.get('KAGGLE_KEY', '').strip()
    if not (user and key):
        from google.colab import files

        print('Upload kaggle.json from https://www.kaggle.com/settings/account')
        uploaded = files.upload()
        if 'kaggle.json' not in uploaded:
            raise FileNotFoundError('kaggle.json not uploaded')
        kaggle_json.write_bytes(uploaded['kaggle.json'])
    else:
        kaggle_json.write_text(json.dumps({'username': user, 'key': key}), encoding='utf-8')

# Owner read/write only, since the file holds an API key.
os.chmod(kaggle_json, 0o600)

run(['kaggle', 'datasets', 'download', '-d', DATASET, '-p', str(RAW_DIR), '--force'])

# Unpack every downloaded archive in place.
for zip_path in sorted(RAW_DIR.glob('*.zip')):
    print('[extract]', zip_path)
    with zipfile.ZipFile(zip_path, 'r') as zf:
        zf.extractall(RAW_DIR)

csv_files = sorted(RAW_DIR.rglob('*.csv'))
npy_files = sorted(RAW_DIR.rglob('*.npy'))
print('csv_found:', len(csv_files), 'npy_found:', len(npy_files))

# Preview up to 20 paths of each kind (CSV first, then NPY).
for p in [*csv_files[:20], *npy_files[:20]]:
    print('-', p)


def norm(name: str) -> str:
    """Normalize a column label for fuzzy matching.

    The label is converted with ``str()`` (pandas labels are not always strings),
    stripped and lowercased. Each run of characters outside ``[a-z0-9]`` becomes
    a single ``_``, and leading/trailing underscores are removed.
    Example: ``' Acc-X '`` -> ``'acc_x'``.
    """
    return re.sub(r'[^a-z0-9]+', '_', str(name).strip().lower()).strip('_')


def detect_axis_map(columns):
    """Map the x/y/z vibration axes to column labels in ``columns``.

    Labels are compared via ``norm``. For each axis, the aliases below are tried
    in order and the first one present wins. The original aliases come first,
    so inputs that matched before still resolve to the same columns.

    Args:
        columns: Iterable of column labels (e.g. ``DataFrame.columns``).

    Returns:
        ``{'x': label, 'y': label, 'z': label}`` using the ORIGINAL labels, or
        ``None`` if any axis has no matching column.
    """
    # If two labels normalize to the same key, the later one wins.
    normalized = {norm(c): c for c in columns}
    patterns = {
        'x': ['x', 'acc_x', 'axis_x', 'x_axis', 'ax', 'vib_x', 'accel_x', 'acceleration_x', 'x_acc', 'x_accel'],
        'y': ['y', 'acc_y', 'axis_y', 'y_axis', 'ay', 'vib_y', 'accel_y', 'acceleration_y', 'y_acc', 'y_accel'],
        'z': ['z', 'acc_z', 'axis_z', 'z_axis', 'az', 'vib_z', 'accel_z', 'acceleration_z', 'z_acc', 'z_accel'],
    }
    axis_map = {}
    for axis, aliases in patterns.items():
        match = next((normalized[alias] for alias in aliases if alias in normalized), None)
        if match is None:
            return None
        axis_map[axis] = match
    return axis_map


# Build one standardized (timestamp, x, y, z) input from the first usable CSV,
# falling back to the first .npy shaped (T, 3). Then write a training config
# that points at it.
prepared = None
for path in csv_files:
    try:
        head = pd.read_csv(path, nrows=5)
    except Exception as exc:
        # Keep scanning past unreadable files, but report why each was skipped.
        print(f'[skip] {path}: cannot read CSV header ({exc})')
        continue
    axis_map = detect_axis_map(list(head.columns))
    if axis_map is None:
        continue

    try:
        df = pd.read_csv(path)
    except Exception as exc:
        print(f'[skip] {path}: cannot read CSV ({exc})')
        continue

    out = pd.DataFrame()

    # Timestamp: the first column named timestamp/time/datetime/date (case- and
    # whitespace-insensitive); otherwise a 0..N-1 sample index.
    col_map = {str(c).strip().lower(): c for c in df.columns}
    ts_col = next((col_map[k] for k in ('timestamp', 'time', 'datetime', 'date') if k in col_map), None)
    if ts_col is None:
        out['timestamp'] = range(len(df))
    else:
        out['timestamp'] = df[ts_col]

    for axis, src_col in axis_map.items():
        out[axis] = pd.to_numeric(df[src_col], errors='coerce')

    # errors='coerce' turns non-numeric cells into NaN. Don't pass those rows to
    # training silently, and reject a file with no usable rows so the next one is tried.
    bad_rows = out[['x', 'y', 'z']].isna().any(axis=1)
    if bad_rows.all():
        print(f'[skip] {path}: no rows with numeric x/y/z values')
        continue
    if bad_rows.any():
        print(f'[warn] {path}: dropping {int(bad_rows.sum())} rows with missing/non-numeric x/y/z')
        out = out.loc[~bad_rows]

    out_csv = OUT_DIR / 'bearing_xyz.csv'
    out.to_csv(out_csv, index=False)
    prepared = out_csv
    print('prepared_csv:', out_csv, 'shape:', out.shape)
    break

if prepared is None:
    # Fallback: first .npy that loads as a 2-D (T, 3) array.
    for path in npy_files:
        try:
            arr = np.load(path)
        except (OSError, ValueError) as exc:
            # e.g. corrupt files, or object arrays that would need allow_pickle (kept off).
            print(f'[skip] {path}: cannot load ({exc})')
            continue
        if arr.ndim == 2 and arr.shape[1] == 3:
            out_npy = OUT_DIR / 'bearing_xyz.npy'
            np.save(out_npy, arr.astype(np.float32))
            prepared = out_npy
            print('prepared_npy:', out_npy, 'shape:', arr.shape)
            break

if prepared is None:
    raise RuntimeError('Could not build standardized vibration input (x,y,z). Check dataset file schema.')

base_cfg = REPO_DIR / 'configs' / 'swinmae_ssl.yaml'
real_cfg = REPO_DIR / 'configs' / 'swinmae_ssl_real.yaml'
# An empty YAML file loads as None; treat it as an empty mapping.
cfg = yaml.safe_load(base_cfg.read_text(encoding='utf-8')) or {}
data_cfg = cfg.setdefault('data', {})

# Point the config at every file of the prepared format in OUT_DIR.
source = 'npy' if prepared.suffix == '.npy' else 'csv'
data_cfg['source'] = source
data_cfg['path'] = (OUT_DIR / f'*.{source}').as_posix()
data_cfg['timestamp_col'] = 'timestamp'
real_cfg.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding='utf-8')

print('real_config_written:', real_cfg)
print('important: set data.fs in configs/swinmae_ssl_real.yaml to real sampling rate')
print('now run the training cell below; it uses *_real.yaml if present')


In [None]:
# DATA_CHECK_SWINMAE
from pathlib import Path
import numpy as np
import pandas as pd

# Sanity-check the standardized vibration files: counts, a sample CSV's schema,
# a sampling-rate estimate from its timestamp column, and a sample NPY's shape.
vib_dir = Path('/content/AnomalyDetection/data/vib')
csv_files = sorted(vib_dir.glob('*.csv'))
npy_files = sorted(vib_dir.glob('*.npy'))

print('vib_dir:', vib_dir)
print('csv_count:', len(csv_files))
print('npy_count:', len(npy_files))

if not csv_files and not npy_files:
    raise FileNotFoundError(f'No CSV/NPY files found in {vib_dir}. Run the Kaggle download cell first.')


def _fs_from_steps(steps):
    """Return 1 / median of the strictly positive values in ``steps``, or None if there are none."""
    positive = steps.dropna()
    positive = positive[positive > 0]
    if len(positive) == 0:
        return None
    return 1.0 / float(positive.median())


if csv_files:
    sample = csv_files[0]
    print('sample_csv:', sample)
    df = pd.read_csv(sample)
    print('shape:', df.shape)
    print('columns:', df.columns.tolist())
    print(df.head(3))

    # Fixed order so the warning is deterministic (set iteration order is not).
    miss = [c for c in ('x', 'y', 'z') if c not in df.columns]
    if miss:
        print('[warn] missing required axis columns:', miss)
    else:
        print('[ok] x,y,z columns exist')

    if 'timestamp' in df.columns:
        ts_raw = df['timestamp']
        fs_est = None
        if pd.api.types.is_numeric_dtype(ts_raw):
            # Numeric timestamps must NOT go through pd.to_datetime, which reads bare
            # numbers as nanoseconds since the epoch. A 0..N-1 index would give fs ~ 1e9.
            fs_est = _fs_from_steps(ts_raw.diff())
            if fs_est is not None:
                print('estimated_fs_from_numeric_timestamp:', fs_est)
                if np.array_equal(ts_raw.to_numpy(), np.arange(len(ts_raw))):
                    print('[warn] timestamp looks like a 0..N-1 sample index; this fs estimate is not a physical sampling rate')
        else:
            ts_dt = pd.to_datetime(ts_raw, errors='coerce')
            fs_est = _fs_from_steps(ts_dt.diff().dt.total_seconds())
            if fs_est is not None:
                print('estimated_fs_from_datetime_median_dt:', fs_est)
            else:
                # Mixed/odd strings: last resort, parse whatever is numeric.
                fs_est = _fs_from_steps(pd.to_numeric(ts_raw, errors='coerce').diff())
                if fs_est is not None:
                    print('estimated_fs_from_numeric_timestamp:', fs_est)
        if fs_est is None:
            print('[warn] could not estimate fs from timestamp')

if npy_files:
    sample_npy = npy_files[0]
    arr = np.load(sample_npy)
    print('sample_npy:', sample_npy)
    print('npy_shape:', arr.shape)
    if arr.ndim == 2 and arr.shape[1] == 3:
        print('[ok] npy shape is (T, 3)')
    else:
        print('[warn] expected npy shape (T, 3)')

print('next: set data.fs in configs/swinmae_ssl_real.yaml to the real sampling rate')


In [None]:
from pathlib import Path
import subprocess
import sys
import shlex


def run(cmd):
    """Run ``cmd`` and stream its combined stdout/stderr into the cell line by line.

    Arguments are converted with ``str()`` (``Path`` objects are fine) and echoed
    shell-quoted. Undecodable output bytes are replaced rather than aborting the
    stream. If streaming is interrupted (e.g. KeyboardInterrupt), the child is
    killed instead of being left running. The pipe is always closed.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    cmd = [str(x) for x in cmd]
    print('$', ' '.join(shlex.quote(x) for x in cmd))
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors='replace',
        bufsize=1,
    ) as proc:
        try:
            for line in proc.stdout:
                print(line, end='')
        except BaseException:
            # Popen.__exit__ only waits; it does not terminate the child on interrupt.
            proc.kill()
            raise
        code = proc.wait()

    if code != 0:
        raise RuntimeError(f"Command failed ({code}): {' '.join(cmd)}")


# Install the project's requirements into this kernel's interpreter.
# requirements.txt is resolved against the cwd set by the bootstrap cell.
req = Path('requirements.txt')
if not req.exists():
    raise FileNotFoundError(
        f"requirements.txt not found in cwd={Path.cwd()}. Run bootstrap cell first or fix repo path."
    )

pip_install = [sys.executable, '-m', 'pip', 'install']
for extra_args in (['-U', 'pip'], ['-r', str(req)]):
    run(pip_install + extra_args)


In [None]:
import torch

# Report the installed torch build and the CUDA devices it can see.
cuda_ok = torch.cuda.is_available()
print('torch:', torch.__version__)
print('cuda_available:', cuda_ok)
if cuda_ok:
    print('cuda_device_count:', torch.cuda.device_count())
    print('cuda_device_0:', torch.cuda.get_device_name(0))


In [None]:
import subprocess
import sys
from pathlib import Path

import os

# Launch SwinMAE SSL training from the repo root explicitly. `python -m trainers...`
# and the relative config path only resolve there, and this cell should not rely
# on an earlier cell's os.chdir having run in this kernel.
REPO_DIR = Path('/content/AnomalyDetection')

cfg = Path('configs/swinmae_ssl_real.yaml')
if not (REPO_DIR / cfg).exists():
    print('[warn] real config not found, fallback to synthetic config')
    cfg = Path('configs/swinmae_ssl.yaml')

cmd = [sys.executable, '-m', 'trainers.train_swinmae_ssl', '--config', str(cfg)]
print('$', ' '.join(cmd))

# Stream the trainer's output into the cell (as the install cell does) instead of
# relying on the kernel to forward the child's stdout. PYTHONUNBUFFERED stops the
# child from block-buffering when its stdout is a pipe.
train_env = {**os.environ, 'PYTHONUNBUFFERED': '1'}
with subprocess.Popen(
    cmd,
    cwd=REPO_DIR,
    env=train_env,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    errors='replace',
    bufsize=1,
) as proc:
    try:
        for line in proc.stdout:
            print(line, end='')
    except BaseException:
        # Interrupting the cell should stop training, not orphan it.
        proc.kill()
        raise
    returncode = proc.wait()

if returncode != 0:
    raise RuntimeError(f"Training failed with exit code {returncode}")


In [None]:
from pathlib import Path

# Confirm training left a checkpoint (path relative to the repo-root cwd) and report its size.
checkpoint_path = Path('checkpoints/swinmae_ssl.pt')
has_checkpoint = checkpoint_path.exists()
print('checkpoint_exists:', has_checkpoint, checkpoint_path)
assert has_checkpoint, f'Missing checkpoint: {checkpoint_path}'
size_bytes = checkpoint_path.stat().st_size
print('checkpoint_size_bytes:', size_bytes)
