In [None]:
# Cell for running the trained model in Colab
import os
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Mount Google Drive when running inside Google Colab.
# Outside Colab the `google.colab` package cannot be imported, which is the
# expected "not on Colab" signal and is handled silently. Any other failure
# (e.g. the mount itself failing) is still non-fatal, as before, but is now
# reported instead of being swallowed without a trace.
try:
    from google.colab import drive
    drive.mount("/content/drive")
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
except Exception as mount_error:
    IN_COLAB = False
    print(f"Atentie: montarea Google Drive a esuat: {mount_error}")

# Path to the folder that holds config.json and the weights
# (upload the folder to Drive and adjust this path accordingly).
MODEL_PATH = "/content/drive/MyDrive/Colab Notebooks/PsihologAI/distilroberta_risk_level/model_cu_weights"

# Sanity checks on the model folder.
# Explicit raises are used instead of `assert` so the checks are not stripped
# when Python runs with optimizations enabled (`-O`).
if not os.path.isdir(MODEL_PATH):
    raise FileNotFoundError(f"Nu gasesc folderul modelului la: {MODEL_PATH}")
if not os.path.exists(os.path.join(MODEL_PATH, "config.json")):
    raise FileNotFoundError("Lipseste config.json din folderul modelului")

# Load the trained model together with its tokenizer from MODEL_PATH
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
model.eval()  # switch the model to evaluation mode for inference

# Tokenizer windowing settings: `max_length` and `stride` passed to the tokenizer
MAX_LEN = 512
STRIDE = 128

# Decision thresholds (tune them according to the results you observe)
TH_HIGH = 0.25
TH_LOW = 0.35

def _id2label_safe(cfg):
    out = {}
    for k, v in cfg.id2label.items():
        out[int(k)] = v
    return out

# Label maps taken from the model config:
# ID2LABEL maps class id -> label name, LABEL2ID is its inverse.
ID2LABEL = _id2label_safe(model.config)
LABEL2ID = {v: k for k, v in ID2LABEL.items()}

# The threshold logic below needs both of these classes to exist.
# Explicit raises are used instead of `assert` so the checks are not stripped
# when Python runs with optimizations enabled (`-O`).
if "high_risk" not in LABEL2ID:
    raise ValueError(f"Nu gasesc 'high_risk' in etichete. Am: {list(LABEL2ID.keys())}")
if "low_risk" not in LABEL2ID:
    raise ValueError(f"Nu gasesc 'low_risk' in etichete. Am: {list(LABEL2ID.keys())}")

# Class ids used to index the probability vector returned by the model
HIGH_ID = LABEL2ID["high_risk"]
LOW_ID = LABEL2ID["low_risk"]

def predict_risk_long_text(text, batch_size=16):
    """Classify ``text`` by averaging class probabilities over token windows.

    The tokenizer splits the text into windows of at most ``MAX_LEN`` tokens
    (``stride=STRIDE``, every window padded to ``MAX_LEN``). Each window is
    scored by the model, the per-window softmax probabilities are averaged,
    and the label is chosen with this cascade:

    1. ``high_risk`` if its mean probability is >= ``TH_HIGH``;
    2. otherwise ``low_risk`` if its mean probability is >= ``TH_LOW``;
    3. otherwise the class with the highest mean probability.

    Note that step 1 can select ``high_risk`` even when another class has a
    larger probability; that is how the thresholds are defined.

    Args:
        text: The input string to classify.
        batch_size: Maximum number of windows sent through the model in one
            forward pass. Windows are processed in mini-batches so that very
            long texts do not need all windows on the device at once.

    Returns:
        A tuple ``(pred_label, probs_doc, n_chunks)``: the predicted label
        name, a NumPy array with the mean probability of every class (indexed
        by class id), and the number of windows that were scored.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    enc = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LEN,
        return_overflowing_tokens=True,
        stride=STRIDE,
        padding="max_length",
        return_tensors="pt",
    )

    # These bookkeeping entries are produced by the tokenizer but are not
    # model inputs, so they are dropped before calling the model.
    enc.pop("overflow_to_sample_mapping", None)
    enc.pop("num_truncated_tokens", None)

    n_windows = enc["input_ids"].shape[0]
    window_probs = []
    with torch.no_grad():
        # Score the windows in mini-batches to bound peak memory usage.
        for start in range(0, n_windows, batch_size):
            batch = {
                key: tensor[start:start + batch_size].to(device)
                for key, tensor in enc.items()
            }
            logits = model(**batch).logits
            window_probs.append(torch.softmax(logits, dim=-1))

    # One row of class probabilities per window, in the original window order
    probs = torch.cat(window_probs, dim=0)

    # Document-level probabilities: mean over all windows
    probs_doc = probs.mean(dim=0)

    p_high = float(probs_doc[HIGH_ID])
    p_low = float(probs_doc[LOW_ID])

    # Threshold cascade: high_risk is checked first, then low_risk, and only
    # when neither threshold is reached does the plain argmax decide.
    if p_high >= TH_HIGH:
        pred_id = HIGH_ID
    elif p_low >= TH_LOW:
        pred_id = LOW_ID
    else:
        pred_id = int(torch.argmax(probs_doc).item())

    pred_label = ID2LABEL[pred_id]
    return pred_label, probs_doc.detach().cpu().numpy(), probs.shape[0]

# Startup banner: model location, class names ordered by class id, thresholds
print("Model incarcat din:", MODEL_PATH)
print("Clase:", [ID2LABEL[i] for i in sorted(ID2LABEL)])
print("TH_HIGH:", TH_HIGH, "| TH_LOW:", TH_LOW)

# Interactive test loop: read a text, classify it, print the scores.
# The loop ends when the user types 'exit', or when input is closed or
# interrupted (handled explicitly so the cell stops without a traceback).
while True:
    try:
        input_TEXT = input("\nScrie un text pentru clasificare (sau 'exit' pentru iesire):\n")
    except (EOFError, KeyboardInterrupt):
        print("Iesire din modul de testare.")
        break

    if input_TEXT.lower().strip() == "exit":
        print("Iesire din modul de testare.")
        break

    # Skip empty or whitespace-only input instead of sending it to the model
    if not input_TEXT.strip():
        print("Text gol. Introdu un text valid.")
        continue

    label, scores, n_chunks = predict_risk_long_text(input_TEXT)

    print("\n--- Rezultat clasificare ---")
    print("Clasa prezisa:", label)
    print("Numar chunk-uri folosite:", n_chunks)
    print(f"Prob high_risk: {scores[HIGH_ID]:.4f}")
    print(f"Prob low_risk : {scores[LOW_ID]:.4f}")
    print("\nScoruri pe clase:")

    # `scores` is indexed by class id, so the position doubles as the id
    for i, score in enumerate(scores):
        print(f"  {ID2LABEL[i]}: {score:.4f}")

    print("-" * 50)


Mounted at /content/drive




Model incarcat din: /content/drive/MyDrive/Colab Notebooks/PsihologAI/distilroberta_risk_level/model_cu_weights
Clase: ['high_risk', 'low_risk', 'medium_risk']
TH_HIGH: 0.25 | TH_LOW: 0.35

Scrie un text pentru clasificare (sau 'exit' pentru iesire):
I am happy 

--- Rezultat clasificare ---
Clasa prezisa: medium_risk
Numar chunk-uri folosite: 1
Prob high_risk: 0.2377
Prob low_risk : 0.2037

Scoruri pe clase:
  high_risk: 0.2377
  low_risk: 0.2037
  medium_risk: 0.5586
--------------------------------------------------

Scrie un text pentru clasificare (sau 'exit' pentru iesire):
i am happy 

--- Rezultat clasificare ---
Clasa prezisa: high_risk
Numar chunk-uri folosite: 1
Prob high_risk: 0.3300
Prob low_risk : 0.1594

Scoruri pe clase:
  high_risk: 0.3300
  low_risk: 0.1594
  medium_risk: 0.5106
--------------------------------------------------

Scrie un text pentru clasificare (sau 'exit' pentru iesire):
i am happy for my sister

--- Rezultat clasificare ---
Clasa prezisa: low_risk
