Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,81 @@
import { buildOneTimePurchaseTransaction, buildSubscriptionTransaction, resolveSelectedPriceFromProduct } from "@/app/api/latest/internal/payments/transactions/transaction-builder";
import { getStripeForAccount } from "@/lib/stripe";
import { getPrismaClientForTenancy } from "@/prisma-client";
import { createSmartRouteHandler } from "@/route-handlers/smart-route-handler";
import type { TransactionEntry } from "@stackframe/stack-shared/dist/interface/crud/transactions";
import { KnownErrors } from "@stackframe/stack-shared/dist/known-errors";
import { adaptSchema, adminAuthTypeSchema, yupBoolean, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { StackAssertionError } from "@stackframe/stack-shared/dist/utils/errors";
import { SubscriptionStatus } from "@/generated/prisma/client";
import { adaptSchema, adminAuthTypeSchema, moneyAmountSchema, productSchema, yupArray, yupBoolean, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { moneyAmountToStripeUnits } from "@stackframe/stack-shared/dist/utils/currencies";
import { SUPPORTED_CURRENCIES, type MoneyAmount } from "@stackframe/stack-shared/dist/utils/currency-constants";
import { StackAssertionError, throwErr } from "@stackframe/stack-shared/dist/utils/errors";
import { InferType } from "yup";

const USD_CURRENCY = SUPPORTED_CURRENCIES.find((currency) => currency.code === "USD")
?? throwErr("USD currency configuration missing in SUPPORTED_CURRENCIES");

function getTotalUsdStripeUnits(options: { product: InferType<typeof productSchema>, priceId: string | null, quantity: number }) {
const selectedPrice = resolveSelectedPriceFromProduct(options.product, options.priceId ?? null);
const usdPrice = selectedPrice?.USD;
if (typeof usdPrice !== "string") {
throw new KnownErrors.SchemaError("Refund amounts can only be specified for USD-priced purchases.");
}
if (!Number.isFinite(options.quantity) || Math.trunc(options.quantity) !== options.quantity) {
throw new StackAssertionError("Purchase quantity is not an integer", { quantity: options.quantity });
}
return moneyAmountToStripeUnits(usdPrice as MoneyAmount, USD_CURRENCY) * options.quantity;
}

type RefundEntrySelection = {
entry_index: number,
quantity: number,
amount_usd: MoneyAmount,
};

function validateRefundEntries(options: { entries: TransactionEntry[], refundEntries: RefundEntrySelection[] }) {
const seenEntryIndexes = new Set<number>();
const entryByIndex = new Map<number, TransactionEntry>(
options.entries.map((entry, index) => [index, entry]),
);

for (const refundEntry of options.refundEntries) {
if (!Number.isFinite(refundEntry.quantity) || Math.trunc(refundEntry.quantity) !== refundEntry.quantity) {
throw new KnownErrors.SchemaError("Refund quantity must be an integer.");
}
if (refundEntry.quantity < 0) {
throw new KnownErrors.SchemaError("Refund quantity cannot be negative.");
}
if (seenEntryIndexes.has(refundEntry.entry_index)) {
throw new KnownErrors.SchemaError("Refund entries cannot contain duplicate entry indexes.");
}
seenEntryIndexes.add(refundEntry.entry_index);
const entry = entryByIndex.get(refundEntry.entry_index);
if (!entry) {
throw new KnownErrors.SchemaError("Refund entry index is invalid.");
}
if (entry.type !== "product_grant") {
throw new KnownErrors.SchemaError("Refund entries must reference product grant entries.");
}
if (refundEntry.quantity > entry.quantity) {
throw new KnownErrors.SchemaError("Refund quantity cannot exceed purchased quantity.");
}
}
}

function getRefundedQuantity(refundEntries: RefundEntrySelection[]) {
let total = 0;
for (const refundEntry of refundEntries) {
total += refundEntry.quantity;
}
return total;
}

function getRefundAmountStripeUnits(refundEntries: RefundEntrySelection[]) {
let total = 0;
for (const refundEntry of refundEntries) {
total += moneyAmountToStripeUnits(refundEntry.amount_usd, USD_CURRENCY);
}
return total;
}

export const POST = createSmartRouteHandler({
metadata: {
Expand All @@ -19,6 +90,13 @@ export const POST = createSmartRouteHandler({
body: yupObject({
type: yupString().oneOf(["subscription", "one-time-purchase"]).defined(),
id: yupString().defined(),
refund_entries: yupArray(
yupObject({
entry_index: yupNumber().integer().defined(),
quantity: yupNumber().integer().defined(),
amount_usd: moneyAmountSchema(USD_CURRENCY).defined(),
}).defined(),
).defined(),
}).defined()
}),
response: yupObject({
Expand All @@ -30,10 +108,13 @@ export const POST = createSmartRouteHandler({
}),
handler: async ({ auth, body }) => {
const prisma = await getPrismaClientForTenancy(auth.tenancy);
const refundEntries = body.refund_entries.map((entry) => ({
...entry,
amount_usd: entry.amount_usd as MoneyAmount,
}));
if (body.type === "subscription") {
const subscription = await prisma.subscription.findUnique({
where: { tenancyId_id: { tenancyId: auth.tenancy.id, id: body.id } },
select: { refundedAt: true },
});
if (!subscription) {
throw new KnownErrors.SubscriptionInvoiceNotFound(body.id);
Expand Down Expand Up @@ -72,16 +153,73 @@ export const POST = createSmartRouteHandler({
if (!paymentIntentId || typeof paymentIntentId !== "string") {
throw new StackAssertionError("Payment has no payment intent", { invoiceId: subscriptionInvoice.stripeInvoiceId });
}
await stripe.refunds.create({ payment_intent: paymentIntentId });
await prisma.subscription.update({
where: { tenancyId_id: { tenancyId: auth.tenancy.id, id: body.id } },
data: {
status: SubscriptionStatus.canceled,
cancelAtPeriodEnd: true,
currentPeriodEnd: new Date(),
refundedAt: new Date(),
},
let refundAmountStripeUnits: number | null = null;
const transaction = buildSubscriptionTransaction({ subscription });
validateRefundEntries({
entries: transaction.entries,
refundEntries,
});
Comment thread
BilalG1 marked this conversation as resolved.
const refundedQuantity = getRefundedQuantity(refundEntries);
const totalStripeUnits = getTotalUsdStripeUnits({
product: subscription.product as InferType<typeof productSchema>,
priceId: subscription.priceId ?? null,
quantity: subscription.quantity,
});
refundAmountStripeUnits = getRefundAmountStripeUnits(refundEntries);
if (refundAmountStripeUnits < 0) {
throw new KnownErrors.SchemaError("Refund amount cannot be negative.");
}
if (refundAmountStripeUnits > totalStripeUnits) {
throw new KnownErrors.SchemaError("Refund amount cannot exceed the charged amount.");
}
Comment thread
BilalG1 marked this conversation as resolved.
await stripe.refunds.create({
payment_intent: paymentIntentId,
amount: refundAmountStripeUnits,
});
if (refundedQuantity > 0) {
if (!subscription.stripeSubscriptionId) {
throw new StackAssertionError("Stripe subscription id missing for refund", { subscriptionId: subscription.id });
}
const stripeSubscription = await stripe.subscriptions.retrieve(subscription.stripeSubscriptionId);
if (stripeSubscription.items.data.length === 0) {
throw new StackAssertionError("Stripe subscription has no items", { subscriptionId: subscription.id });
}
const subscriptionItem = stripeSubscription.items.data[0];
if (!Number.isFinite(subscriptionItem.quantity) || Math.trunc(subscriptionItem.quantity ?? 0) !== subscriptionItem.quantity) {
throw new StackAssertionError("Stripe subscription item quantity is not an integer", {
subscriptionId: subscription.id,
itemQuantity: subscriptionItem.quantity,
});
}
const currentQuantity = subscriptionItem.quantity ?? 0;
const newQuantity = currentQuantity - refundedQuantity;
if (newQuantity < 0) {
throw new StackAssertionError("Refund quantity exceeds Stripe subscription item quantity", {
subscriptionId: subscription.id,
currentQuantity,
refundedQuantity,
});
}
await stripe.subscriptions.update(subscription.stripeSubscriptionId, {
cancel_at_period_end: newQuantity === 0,
items: [{
id: subscriptionItem.id,
quantity: newQuantity,
}],
});
await prisma.subscription.update({
where: { tenancyId_id: { tenancyId: auth.tenancy.id, id: body.id } },
data: {
cancelAtPeriodEnd: newQuantity === 0,
refundedAt: new Date(),
},
});
} else {
await prisma.subscription.update({
where: { tenancyId_id: { tenancyId: auth.tenancy.id, id: body.id } },
data: { refundedAt: new Date() },
});
}
Comment thread
BilalG1 marked this conversation as resolved.
} else {
const purchase = await prisma.oneTimePurchase.findUnique({
where: { tenancyId_id: { tenancyId: auth.tenancy.id, id: body.id } },
Expand All @@ -99,8 +237,27 @@ export const POST = createSmartRouteHandler({
if (!purchase.stripePaymentIntentId) {
throw new KnownErrors.OneTimePurchaseNotFound(body.id);
}
let refundAmountStripeUnits: number | null = null;
const transaction = buildOneTimePurchaseTransaction({ purchase });
validateRefundEntries({
entries: transaction.entries,
refundEntries,
});
const totalStripeUnits = getTotalUsdStripeUnits({
product: purchase.product as InferType<typeof productSchema>,
priceId: purchase.priceId ?? null,
quantity: purchase.quantity,
});
refundAmountStripeUnits = getRefundAmountStripeUnits(refundEntries);
if (refundAmountStripeUnits < 0) {
throw new KnownErrors.SchemaError("Refund amount cannot be negative.");
}
if (refundAmountStripeUnits > totalStripeUnits) {
throw new KnownErrors.SchemaError("Refund amount cannot exceed the charged amount.");
}
await stripe.refunds.create({
payment_intent: purchase.stripePaymentIntentId,
amount: refundAmountStripeUnits,
metadata: {
tenancyId: auth.tenancy.id,
purchaseId: purchase.id,
Expand Down
Loading