In [None]:
import pandas as pd
import boto3

# Configuration: S3 locations for the raw and processed training data.
BUCKET = "sagemaker-us-east-1-084375569056"
RAW_KEY = "data/train.csv"
PROCESSED_KEY = "data/train_processed.csv"
LABEL_COL = "Churn"


def _encode_object_column(series: pd.Series) -> pd.Series:
    """Turn an object (string) column into a numeric column.

    A column whose every non-blank value parses as a number (e.g. a numeric
    column that contains blank strings) is kept numeric, with the blanks
    filled with 0. Any other column is treated as categorical and replaced by
    its integer category codes (-1 for missing values).
    """
    stripped = series.map(lambda v: v.strip() if isinstance(v, str) else v)
    # Blank / whitespace-only strings count as missing, not as a value.
    present = stripped.mask(stripped == "")
    numeric = pd.to_numeric(present, errors="coerce")
    if numeric.notna().sum() == present.notna().sum():
        return numeric.fillna(0)
    return series.astype("category").cat.codes


def preprocess_churn(raw: pd.DataFrame, label_col: str = LABEL_COL) -> pd.DataFrame:
    """Prepare the churn dataset as a fully numeric, label-first frame.

    Args:
        raw: Raw training frame as read from CSV; must contain ``label_col``
            with "Yes"/"No" values.
        label_col: Name of the target column.

    Returns:
        A new frame with the 0/1 label as the first column followed by the
        feature columns. ``customerID`` is dropped if present. Object columns
        that hold numbers are parsed as numbers; other object columns are
        replaced by category codes.

    Raises:
        ValueError: If ``label_col`` is missing, or contains values other
            than "Yes"/"No" (including missing values).
    """
    if label_col not in raw.columns:
        raise ValueError(f"Column '{label_col}' not found.")

    labels = raw[label_col].map({"Yes": 1, "No": 0})
    if labels.isna().any():
        bad = raw.loc[labels.isna(), label_col].unique().tolist()
        raise ValueError(f"Column '{label_col}' has values other than 'Yes'/'No': {bad!r}")

    # The identifier column is not a feature; drop it if present.
    features = raw.drop(columns=[label_col, "customerID"], errors="ignore")
    object_cols = features.select_dtypes(include=["object"]).columns
    features = features.assign(
        **{col: _encode_object_column(features[col]) for col in object_cols}
    )

    # Label first, then features.
    return pd.concat([labels.rename(label_col), features], axis=1)


s3 = boto3.client("s3")
s3.download_file(BUCKET, RAW_KEY, "train.csv")
train_raw = pd.read_csv("train.csv")
df = preprocess_churn(train_raw)
# SageMaker's built-in XGBoost reads headerless CSV with the target in the
# first column, hence header=False and the label-first column order.
df.to_csv("train_processed.csv", header=False, index=False)
s3.upload_file("train_processed.csv", BUCKET, PROCESSED_KEY)
print("Uploaded processed dataset to S3")

In [None]:
import sagemaker
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput

# Training configuration. The role ARN and S3 locations are account-specific;
# override them when running in another account or region.
REGION = "us-east-1"
ROLE_ARN = "arn:aws:iam::084375569056:role/service-role/AmazonSageMaker-ExecutionRole-20250520T093901"
TRAIN_S3_URI = "s3://sagemaker-us-east-1-084375569056/data/train_processed.csv"
OUTPUT_S3_URI = "s3://sagemaker-us-east-1-084375569056/output/"
# Pin an explicit built-in XGBoost framework version: "latest" does not
# identify a specific release, so training would not be reproducible.
# NOTE(review): confirm this version is available in your SDK / region.
XGBOOST_VERSION = "1.7-1"

session = sagemaker.Session()
estimator = Estimator(
    image_uri=sagemaker.image_uris.retrieve("xgboost", region=REGION, version=XGBOOST_VERSION),
    role=ROLE_ARN,
    instance_count=1,
    instance_type="ml.m5.large",
    output_path=OUTPUT_S3_URI,
    sagemaker_session=session,
)
estimator.set_hyperparameters(
    max_depth=5,
    eta=0.2,
    gamma=4,
    min_child_weight=6,
    subsample=0.8,
    objective="binary:logistic",
    num_round=100,
)
# The training file is headerless CSV with the label in the first column.
train_input = TrainingInput(
    s3_data=TRAIN_S3_URI,
    content_type="text/csv",
)
estimator.fit({"train": train_input})

In [None]:
from sagemaker.predictor import Predictor
from sagemaker.serializers import CSVSerializer

# Deploy the trained model to a real-time endpoint, send a small smoke-test
# request, and always tear everything down again so it stops incurring cost.
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.t2.medium",
    serializer=CSVSerializer(),
)
try:
    # Feature columns only (label dropped), in the same order used for training.
    sample_features = df.drop(columns=["Churn"]).head(5).to_numpy()
    response = predictor.predict(sample_features)
    print(response.decode("utf-8") if isinstance(response, bytes) else response)
finally:
    # delete_endpoint() does not remove the SageMaker Model resource,
    # so delete that explicitly as well.
    predictor.delete_endpoint()
    predictor.delete_model()