In [None]:
import re, numpy as np, scipy.sparse as sp
from collections import Counter
from sklearn.decomposition import TruncatedSVD
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score, f1_score, classification_report
from datasets import load_dataset
import warnings; warnings.filterwarnings("ignore")

# --- Tokenization & Preprocessing ---
def simple_tokenize(t):
    """Lower-case *t* and split it into word tokens.

    After lower-casing, every character that is not ``a-z``, ``0-9``, an
    apostrophe, or whitespace is replaced by a space (so non-ASCII letters
    are removed too). The result is split on whitespace, and tokens shorter
    than two characters are discarded.
    """
    normalized = re.sub(r"[^a-z0-9'\s]", " ", t.lower())
    return list(filter(lambda tok: len(tok) >= 2, normalized.split()))
def preprocess(corpus):
    """Tokenize every document in *corpus* with ``simple_tokenize``.

    Returns a list of token lists, one per input document, in input order.
    """
    return list(map(simple_tokenize, corpus))

# --- Vocabulary & Domain Split ---
def build_vocab(src, tgt):
    """Count token frequencies in the source and target corpora.

    Args:
        src: iterable of tokenized source documents (lists of tokens).
        tgt: iterable of tokenized target documents.

    Returns:
        ``(src_counts, tgt_counts, vocab)``: two ``Counter`` objects with
        per-domain token counts, and the set of all tokens seen in either.
    """
    src_counts = Counter()
    for doc in src:
        src_counts.update(doc)
    tgt_counts = Counter()
    for doc in tgt:
        tgt_counts.update(doc)
    return src_counts, tgt_counts, src_counts.keys() | tgt_counts.keys()

def split_words(s, t, vocab, min_freq=5, ratio=5.0):
    """Partition *vocab* into domain-independent and domain-specific words.

    Words whose combined count ``s[w] + t[w]`` is below *min_freq* are
    ignored. A word seen in both domains is judged by the add-one smoothed
    ratio ``(s[w] + 1) / (t[w] + 1)``:

    * above *ratio* -> source-specific,
    * within ``[1/ratio, ratio]`` -> domain-independent,
    * below ``1/ratio`` -> target-specific.

    A word with a positive count on only the source side is
    source-specific; otherwise it is target-specific.

    Returns:
        ``(independent, source_specific, target_specific)`` as three sets.
    """
    shared, src_only, tgt_only = set(), set(), set()
    for word in vocab:
        n_src = s.get(word, 0)
        n_tgt = t.get(word, 0)
        if n_src + n_tgt < min_freq:
            continue
        if not (n_src > 0 and n_tgt > 0):
            (src_only if n_src > 0 else tgt_only).add(word)
            continue
        skew = (n_src + 1) / (n_tgt + 1)
        lower = 1 / ratio
        if skew > ratio:
            src_only.add(word)
        elif skew >= lower:
            shared.add(word)
        else:
            tgt_only.add(word)
    return shared, src_only, tgt_only

# --- Co-occurrence Matrix ---
def cooc(docs, spec, indep, win=5):
    """Count windowed co-occurrences of specific words with independent words.

    Rows index ``sorted(spec)`` and columns index ``sorted(indep)``. For each
    occurrence of a word from *spec* at position ``pos`` in a document, every
    word from *indep* inside ``doc[pos - win : pos + win + 1]`` (clipped to
    the document bounds) contributes one count to the corresponding cell.
    Repeated (row, col) pairs are summed by the CSR constructor.

    Returns:
        ``(matrix, spec_list, indep_list)`` where ``matrix`` is a
        ``scipy.sparse`` CSR matrix of shape ``(len(spec), len(indep))``.
    """
    spec_list = sorted(spec)
    indep_list = sorted(indep)
    row_of = dict(zip(spec_list, range(len(spec_list))))
    col_of = dict(zip(indep_list, range(len(indep_list))))

    rows, cols = [], []
    for doc in docs:
        doc_len = len(doc)
        for pos, word in enumerate(doc):
            if word not in row_of:
                continue
            row = row_of[word]
            start = max(0, pos - win)
            stop = min(doc_len, pos + win + 1)
            for neighbour in doc[start:stop]:
                if neighbour in col_of:
                    rows.append(row)
                    cols.append(col_of[neighbour])

    data = [1] * len(rows)
    matrix = sp.csr_matrix((data, (rows, cols)), shape=(len(spec_list), len(indep_list)))
    return matrix, spec_list, indep_list

# --- Spectral Feature Alignment ---
def sfa(M, dims=100):
    """Embed the rows of *M* with a truncated SVD and L2-normalise them.

    If *M* has no rows or no columns, a zero array of shape
    ``(M.shape[0], dims)`` is returned. Otherwise, the number of components
    is ``min(dims, max(min(M.shape) - 1, 1))``, so the result may be
    narrower than *dims*. Each output row is divided by its L2 norm plus
    ``1e-8``; the epsilon avoids division by zero for all-zero rows.
    """
    n_rows, n_cols = M.shape
    if not (n_rows and n_cols):
        return np.zeros((n_rows, dims))
    n_components = min(dims, max(min(n_rows, n_cols) - 1, 1))
    reducer = TruncatedSVD(n_components=n_components, random_state=42)
    emb = reducer.fit_transform(M)
    emb /= np.linalg.norm(emb, axis=1, keepdims=True) + 1e-8
    return emb

# --- Word Embeddings ---
def build_emb(spec, indep, W, M):
    """Build embeddings for domain-specific and domain-independent words.

    Args:
        spec: domain-specific words; ``spec[i]`` labels row ``i`` of *W*
            and of *M*.
        indep: domain-independent words; ``indep[j]`` labels column ``j``
            of *M*.
        W: array of shape ``(len(spec), k)`` with one embedding per
            specific word (e.g. the output of ``sfa``).
        M: co-occurrence counts of shape ``(len(spec), len(indep))``, as a
            scipy sparse matrix or a dense array.

    Returns:
        ``(spec_emb, indep_emb)``: ``spec_emb[w]`` is the row of *W* for a
        specific word, and ``indep_emb[w]`` is the co-occurrence-weighted
        mean of the specific-word embeddings for an independent word. An
        independent word with no co-occurrences gets a zero vector.
    """
    W = np.asarray(W)
    M = sp.csr_matrix(M)
    # One sparse product replaces densifying M and looping over its columns:
    # row j of M.T @ W equals sum_i M[i, j] * W[i].
    weighted = np.asarray(M.T @ W)
    totals = np.asarray(M.sum(axis=0)).ravel()
    # Columns without co-occurrences divide by 1, which keeps the zero vector.
    totals = np.where(totals == 0, 1, totals)
    indep_emb = {w: weighted[j] / totals[j] for j, w in enumerate(indep)}
    spec_emb = {w: W[i] for i, w in enumerate(spec)}
    return spec_emb, indep_emb

# --- Document Vectorization ---
def doc_vec(toks, spec_emb, indep_emb, dim=50):
    """Represent a token list as the mean of its known word embeddings.

    Tokens found in *spec_emb* and tokens found in *indep_emb* all contribute
    (repeated tokens count every time). If no token has an embedding, a zero
    vector of length *dim* is returned.
    """
    spec_part = [spec_emb[tok] for tok in toks if tok in spec_emb]
    indep_part = [indep_emb[tok] for tok in toks if tok in indep_emb]
    stacked = spec_part + indep_part
    if not stacked:
        return np.zeros(dim)
    return np.mean(stacked, 0)

# --- Full SFA Pipeline ---
def run_sfa(src_texts, src_labels, tgt_texts, tgt_labels, min_freq=5, dims=100, n=2000):
    """Train on the source domain with SFA-style features and evaluate on the target.

    Pipeline:
    1. Tokenize the first *n* documents of each domain.
    2. Split the vocabulary into domain-independent, source-specific and
       target-specific words.
    3. Count windowed co-occurrences of specific words with independent ones.
    4. Embed the specific words with a truncated SVD of the stacked
       co-occurrence matrix.
    5. Represent each document as the mean embedding of its known words.
    6. Fit a ``LinearSVC`` on the source documents.
    7. Print accuracy, macro-F1 and a classification report on the target
       documents.

    Args:
        src_texts, src_labels: source-domain texts and labels (training).
        tgt_texts, tgt_labels: target-domain texts and labels (evaluation).
        min_freq: minimum combined count for a word to be considered.
        dims: requested embedding size (``sfa`` may use fewer components).
        n: maximum number of documents taken from each domain.
    """
    tok_src, tok_tgt = preprocess(src_texts[:n]), preprocess(tgt_texts[:n])
    s, t, vocab = build_vocab(tok_src, tok_tgt)
    di, ss, ts = split_words(s, t, vocab, min_freq)
    all_docs = tok_src + tok_tgt

    M_s, sl_s, indep = cooc(all_docs, ss, di)
    M_t, sl_t, _ = cooc(all_docs, ts, di)
    # Rows of M (and therefore of W) are the source-specific words followed
    # by the target-specific words.
    M = sp.vstack([M_s, M_t]).tocsr()
    W = sfa(M, dims)

    # Derive every embedding from the stacked matrix. Building the pivot
    # (domain-independent) vectors from the source block alone would give a
    # zero vector to pivots that co-occur only with target-specific words,
    # which weakens the cross-domain bridge the pivots are meant to provide.
    spec_emb, indep_emb = build_emb(sl_s + sl_t, indep, W, M)

    dim = W.shape[1]
    Xs = np.vstack([doc_vec(d, spec_emb, indep_emb, dim) for d in tok_src])
    Xt = np.vstack([doc_vec(d, spec_emb, indep_emb, dim) for d in tok_tgt])

    y_src, y_tgt = src_labels[:n], tgt_labels[:n]
    clf = LinearSVC(random_state=42, max_iter=5000).fit(Xs, y_src)
    y_pred = clf.predict(Xt)

    print("Accuracy:", round(accuracy_score(y_tgt, y_pred), 4))
    print("Macro-F1:", round(f1_score(y_tgt, y_pred, average='macro'), 4))
    print(classification_report(y_tgt, y_pred))

# --- Run on Amazon → Yelp ---
# Source domain: first 5000 Amazon reviews; target domain: first 5000 Yelp
# reviews. run_sfa then uses at most n of each.
amazon = load_dataset("amazon_polarity", split="train[:5000]")
yelp = load_dataset("yelp_polarity", split="train[:5000]")
run_sfa(
    src_texts=[row["content"] for row in amazon],
    src_labels=[row["label"] for row in amazon],
    tgt_texts=[row["text"] for row in yelp],
    tgt_labels=[row["label"] for row in yelp],
    min_freq=3,
    dims=50,
    n=2000,
)
